From 54941120e9132a0baea5f197f219e95c50b6580f Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Mon, 31 Oct 2016 09:58:18 +0100 Subject: [PATCH] Pretrained model test --- .../test/test_cnn_pretrained_model.py | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 bob/learn/tensorflow/test/test_cnn_pretrained_model.py diff --git a/bob/learn/tensorflow/test/test_cnn_pretrained_model.py b/bob/learn/tensorflow/test/test_cnn_pretrained_model.py new file mode 100644 index 00000000..1b579e88 --- /dev/null +++ b/bob/learn/tensorflow/test/test_cnn_pretrained_model.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : +# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch> +# @date: Thu 13 Oct 2016 13:35 CEST + +import numpy +import bob.io.base +import os +from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation +from bob.learn.tensorflow.loss import BaseLoss +from bob.learn.tensorflow.trainers import Trainer, constant +from bob.learn.tensorflow.util import load_mnist +import tensorflow as tf +import shutil + +""" +Some unit tests that create networks on the fly and load variables +""" + +batch_size = 16 +validation_batch_size = 400 +iterations = 50 +seed = 10 + +from test_cnn_scratch import scratch_network, validate_network + + +def test_cnn_trainer_scratch(): + train_data, train_labels, validation_data, validation_labels = load_mnist() + train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1)) + + # Creating datashufflers + data_augmentation = ImageAugmentation() + train_data_shuffler = Memory(train_data, train_labels, + input_shape=[28, 28, 1], + batch_size=batch_size, + data_augmentation=data_augmentation) + validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1)) + + directory = "./temp/cnn" + directory2 = "./temp/cnn2" + + # Creating a random network + scratch = scratch_network() + + # Loss for the softmax + loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean) + + # One graph trainer + trainer = Trainer(architecture=scratch, + loss=loss, + iterations=iterations, + analizer=None, + prefetch=False, + learning_rate=constant(0.05, name="lr"), + temp_dir=directory) + trainer.train(train_data_shuffler) + + accuracy = validate_network(validation_data, validation_labels, directory) + assert accuracy > 85 + + del scratch + del loss + # Training the network using a pre trained model + loss2 = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean, name="loss2") + scratch = scratch_network() + trainer2 = Trainer(architecture=scratch, + loss=loss2, + iterations=iterations, + analizer=None, + prefetch=False, + learning_rate=constant(0.05, name="lr2"), + temp_dir=directory2, + model_from_file=os.path.join(directory, "model.hdf5")) + + trainer2.train(train_data_shuffler) + accuracy = validate_network(validation_data, validation_labels, directory) + assert accuracy > 90 -- GitLab