Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
54941120
Commit
54941120
authored
Oct 31, 2016
by
Tiago de Freitas Pereira
Browse files
Pretrained model test
parent
136a436c
Changes
1
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/test/test_cnn_pretrained_model.py
0 → 100644
View file @
54941120
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 13 Oct 2016 13:35 CEST
import
numpy
import
bob.io.base
import
os
from
bob.learn.tensorflow.datashuffler
import
Memory
,
ImageAugmentation
from
bob.learn.tensorflow.loss
import
BaseLoss
from
bob.learn.tensorflow.trainers
import
Trainer
,
constant
from
bob.learn.tensorflow.util
import
load_mnist
import
tensorflow
as
tf
import
shutil
"""
Some unit tests that create networks on the fly and load variables
"""
batch_size
=
16
validation_batch_size
=
400
iterations
=
50
seed
=
10
from
test_cnn_scratch
import
scratch_network
,
validate_network
def
test_cnn_trainer_scratch
():
train_data
,
train_labels
,
validation_data
,
validation_labels
=
load_mnist
()
train_data
=
numpy
.
reshape
(
train_data
,
(
train_data
.
shape
[
0
],
28
,
28
,
1
))
# Creating datashufflers
data_augmentation
=
ImageAugmentation
()
train_data_shuffler
=
Memory
(
train_data
,
train_labels
,
input_shape
=
[
28
,
28
,
1
],
batch_size
=
batch_size
,
data_augmentation
=
data_augmentation
)
validation_data
=
numpy
.
reshape
(
validation_data
,
(
validation_data
.
shape
[
0
],
28
,
28
,
1
))
directory
=
"./temp/cnn"
directory2
=
"./temp/cnn2"
# Creating a random network
scratch
=
scratch_network
()
# Loss for the softmax
loss
=
BaseLoss
(
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
,
tf
.
reduce_mean
)
# One graph trainer
trainer
=
Trainer
(
architecture
=
scratch
,
loss
=
loss
,
iterations
=
iterations
,
analizer
=
None
,
prefetch
=
False
,
learning_rate
=
constant
(
0.05
,
name
=
"lr"
),
temp_dir
=
directory
)
trainer
.
train
(
train_data_shuffler
)
accuracy
=
validate_network
(
validation_data
,
validation_labels
,
directory
)
assert
accuracy
>
85
del
scratch
del
loss
# Training the network using a pre trained model
loss2
=
BaseLoss
(
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
,
tf
.
reduce_mean
,
name
=
"loss2"
)
scratch
=
scratch_network
()
trainer2
=
Trainer
(
architecture
=
scratch
,
loss
=
loss2
,
iterations
=
iterations
,
analizer
=
None
,
prefetch
=
False
,
learning_rate
=
constant
(
0.05
,
name
=
"lr2"
),
temp_dir
=
directory2
,
model_from_file
=
os
.
path
.
join
(
directory
,
"model.hdf5"
))
trainer2
.
train
(
train_data_shuffler
)
accuracy
=
validate_network
(
validation_data
,
validation_labels
,
directory
)
assert
accuracy
>
90
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment