diff --git a/bob/learn/tensorflow/script/train.py b/bob/learn/tensorflow/script/train.py
index c26b53faaecfb2abeb0448f409cb4a40011f360f..015f77a70a8cd541b8d7a2739f9444e113d187ff 100644
--- a/bob/learn/tensorflow/script/train.py
+++ b/bob/learn/tensorflow/script/train.py
@@ -13,7 +13,7 @@ Options:
-h --help Show this screen.
--iterations=<arg> [default: 1000]
--validation-interval=<arg> [default: 100]
- --output-dir=<arg> [default: ./logs/]
+ --output-dir=<arg> If the directory already exists, training resumes from the last checkpoint found there [default: ./logs/]
--pretrained-net=<arg>
"""
@@ -21,7 +21,8 @@ Options:
from docopt import docopt
import imp
import bob.learn.tensorflow
-
+import tensorflow as tf
+import os
+
def main():
args = docopt(__doc__, version='Train Neural Net')
@@ -35,6 +36,7 @@ def main():
if not args['--pretrained-net'] is None:
PRETRAINED_NET = str(args['--pretrained-net'])
+
config = imp.load_source('config', args['<configuration>'])
# One graph trainer
@@ -43,22 +45,31 @@ def main():
analizer=None,
temp_dir=OUTPUT_DIR)
- # Preparing the architecture
- input_pl = config.train_data_shuffler("data", from_queue=False)
- if isinstance(trainer, bob.learn.tensorflow.trainers.SiameseTrainer):
- graph = dict()
- graph['left'] = config.architecture(input_pl['left'])
- graph['right'] = config.architecture(input_pl['right'], reuse=True)
-
- elif isinstance(trainer, bob.learn.tensorflow.trainers.TripletTrainer):
- graph = dict()
- graph['anchor'] = config.architecture(input_pl['anchor'])
- graph['positive'] = config.architecture(input_pl['positive'], reuse=True)
- graph['negative'] = config.architecture(input_pl['negative'], reuse=True)
+
+ if os.path.exists(OUTPUT_DIR):
+ print("Directory already exists, trying to get the last checkpoint")
+ trainer.create_network_from_file(OUTPUT_DIR)
+
else:
- graph = config.architecture(input_pl)
- trainer.create_network_from_scratch(graph, loss=config.loss,
- learning_rate=config.learning_rate,
- optimizer=config.optimizer)
- trainer.train(config.train_data_shuffler)
+ # Preparing the architecture
+ input_pl = config.train_data_shuffler("data", from_queue=False)
+ if isinstance(trainer, bob.learn.tensorflow.trainers.SiameseTrainer):
+ graph = dict()
+ graph['left'] = config.architecture(input_pl['left'])
+ graph['right'] = config.architecture(input_pl['right'], reuse=True)
+
+ elif isinstance(trainer, bob.learn.tensorflow.trainers.TripletTrainer):
+ graph = dict()
+ graph['anchor'] = config.architecture(input_pl['anchor'])
+ graph['positive'] = config.architecture(input_pl['positive'], reuse=True)
+ graph['negative'] = config.architecture(input_pl['negative'], reuse=True)
+ else:
+ graph = config.architecture(input_pl)
+
+ trainer.create_network_from_scratch(graph, loss=config.loss,
+ learning_rate=config.learning_rate,
+ optimizer=config.optimizer)
+ trainer.train()
diff --git a/bob/learn/tensorflow/test/data/train_scripts/softmax.py b/bob/learn/tensorflow/test/data/train_scripts/softmax.py
index 99bd3917e890bfec86240893fa76a4fe10f1af11..ae16cb4306dfae1fce901b716426dface4becba7 100644
--- a/bob/learn/tensorflow/test/data/train_scripts/softmax.py
+++ b/bob/learn/tensorflow/test/data/train_scripts/softmax.py
@@ -1,7 +1,7 @@
-from bob.learn.tensorflow.datashuffler import Memory
+from bob.learn.tensorflow.datashuffler import Memory, ScaleFactor
from bob.learn.tensorflow.network import Chopra
from bob.learn.tensorflow.trainers import Trainer, constant
-from bob.learn.tensorflow.loss import BaseLoss
+from bob.learn.tensorflow.loss import MeanSoftMaxLoss
from bob.learn.tensorflow.utils import load_mnist
import tensorflow as tf
import numpy
@@ -18,19 +18,23 @@ train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
train_data_shuffler = Memory(train_data, train_labels,
input_shape=INPUT_SHAPE,
- batch_size=BATCH_SIZE)
+ batch_size=BATCH_SIZE,
+ normalizer=ScaleFactor())
### ARCHITECTURE ###
-architecture = Chopra(seed=SEED, fc1_output=10, batch_norm=False)
+architecture = Chopra(seed=SEED, n_classes=10)
### LOSS ###
-loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
-
-### SOLVER ###
-optimizer = tf.train.GradientDescentOptimizer(0.001)
+loss = MeanSoftMaxLoss()
### LEARNING RATE ###
-learning_rate = constant(base_learning_rate=0.001)
+learning_rate = constant(base_learning_rate=0.01)
+
+
+### SOLVER ###
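+# the solver is built after the learning rate, since the optimizer consumes it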
+optimizer = tf.train.GradientDescentOptimizer(learning_rate)
### Trainer ###
trainer = Trainer
diff --git a/bob/learn/tensorflow/test/test_train_script.py b/bob/learn/tensorflow/test/test_train_script.py
index e5b19a820e62f6c25994a0ddbbebf98004cf7449..86ea0f4d9938b0bdd67cdef0a93ad89ddbf0b02d 100644
--- a/bob/learn/tensorflow/test/test_train_script.py
+++ b/bob/learn/tensorflow/test/test_train_script.py
@@ -10,10 +10,10 @@ import shutil
def test_train_script_softmax():
directory = "./temp/train-script"
train_script = pkg_resources.resource_filename(__name__, './data/train_scripts/softmax.py')
- #train_script = './data/train_scripts/softmax.py'
- #from subprocess import call
- #call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
+ from subprocess import call
+ status = call(["./bin/train.py", "--iterations", "5", "--output-dir", directory, train_script])
- #shutil.rmtree(directory)
- assert True
+ shutil.rmtree(directory)
+ assert status == 0