Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.pytorch
Commits
06d5b2df
Commit
06d5b2df
authored
Jan 22, 2019
by
Olegs NIKISINS
Browse files
Added dataset class, Conv-AE model, config to train on CelebA, and train script
parent
f979eb7d
Pipeline
#26220
failed with stage
in 7 minutes and 15 seconds
Changes
9
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
bob/learn/pytorch/architectures/ConvAutoencoder.py
0 → 100644
View file @
06d5b2df
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
@author: Olegs Nikisins
"""
#==============================================================================
# Import here:
from
torch
import
nn
#==============================================================================
# Define the network:
class ConvAutoencoder(nn.Module):
    """
    Convolutional autoencoder made of two ``nn.Sequential`` stacks.

    ``encoder``: four stages of (Conv2d -> in-place ReLU -> MaxPool2d(2)).
    The first convolution takes 3 input channels, all feature maps have
    16 channels, and every convolution uses ``padding=2``.

    ``decoder``: four stages of (ConvTranspose2d with stride 2 -> in-place
    ReLU), followed by a stride-1 ConvTranspose2d producing 3 channels and
    a final Tanh.
    """

    def __init__(self):
        super(ConvAutoencoder, self).__init__()

        # Each entry: (in_channels, out_channels, kernel_size).
        encoder_specs = [(3, 16, 5), (16, 16, 5), (16, 16, 3), (16, 16, 3)]
        encoder_layers = []
        for n_in, n_out, kernel in encoder_specs:
            encoder_layers.append(nn.Conv2d(n_in, n_out, kernel, padding=2))
            encoder_layers.append(nn.ReLU(True))
            encoder_layers.append(nn.MaxPool2d(2))
        self.encoder = nn.Sequential(*encoder_layers)

        # Each entry: (in_channels, out_channels, kernel_size, padding);
        # all of these transposed convolutions use stride 2.
        decoder_specs = [(16, 16, 3, 1), (16, 16, 3, 1),
                         (16, 16, 5, 2), (16, 3, 5, 2)]
        decoder_layers = []
        for n_in, n_out, kernel, pad in decoder_specs:
            decoder_layers.append(
                nn.ConvTranspose2d(n_in, n_out, kernel, stride=2, padding=pad))
            decoder_layers.append(nn.ReLU(True))
        # Last stage keeps 3 channels, stride 1, and squashes with Tanh.
        decoder_layers.append(nn.ConvTranspose2d(3, 3, 2, stride=1, padding=1))
        decoder_layers.append(nn.Tanh())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run ``x`` through the encoder, then the decoder, and return it."""
        return self.decoder(self.encoder(x))
bob/learn/pytorch/architectures/__init__.py
View file @
06d5b2df
...
...
@@ -7,6 +7,7 @@ from .DCGAN import DCGAN_discriminator
from
.ConditionalGAN
import
ConditionalGAN_generator
from
.ConditionalGAN
import
ConditionalGAN_discriminator
from
.ConvAutoencoder
import
ConvAutoencoder
from
.utils
import
weights_init
...
...
bob/learn/pytorch/config/__init__.py
0 → 100644
View file @
06d5b2df
bob/learn/pytorch/config/autoencoder/__init__.py
0 → 100644
View file @
06d5b2df
bob/learn/pytorch/config/autoencoder/net1_celeba.py
0 → 100644
View file @
06d5b2df
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
@author: Olegs Nikisins
"""
#==============================================================================
# Import here:
from
torchvision
import
transforms
from
bob.pad.face.database
import
CELEBAPadDatabase
from
torch
import
nn
#==============================================================================
# Define parameters here:
"""
Note: do not change names of the below constants.
"""
# NOTE: the training script looks these constants up by name; keep the names.
NUM_EPOCHS = 70  # Maximum number of epochs
BATCH_SIZE = 32  # Size of the batch
LEARNING_RATE = 1e-3  # Learning rate
NUM_WORKERS = 8  # The number of workers for the DataLoader

# Transformations to be applied sequentially to the input PIL image.
# Note: the variable name ``transform`` must be the same in all
# configuration files.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Parameters of the DataFolder dataset class.
# Note: do not change the name ``kwargs``.
bob_hldi_instance = CELEBAPadDatabase(original_directory="",
                                      original_extension="")

kwargs = {
    "data_folder": "NO NEED TO SET HERE, WILL BE SET IN THE TRAINING SCRIPT",
    "transform": transform,
    "extension": '.hdf5',
    "bob_hldi_instance": bob_hldi_instance,
    "hldi_type": "pad",
    "groups": ['train'],
    "protocol": 'grandtest',
    "purposes": ['real'],
    "allow_missing_files": True,
}

# The network to be trained must be exposed as a class named ``Network``.
# Note: do not change the name of the class below.
from bob.learn.pytorch.architectures import ConvAutoencoder as Network

# The loss used for training.
# Note: do not change the name of the variable below.
loss_type = nn.MSELoss()

# OPTIONAL: if not defined, the loss will be computed in the training script.
# See the training script for details.
# A custom loss function may be defined here; don't change its signature.
#
# we don't define the loss_function for this configuration
#def loss_function(output, img, target):
bob/learn/pytorch/datasets/__init__.py
View file @
06d5b2df
from
.casia_webface
import
CasiaDataset
from
.casia_webface
import
CasiaWebFaceDataset
from
.data_folder
import
DataFolder
# transforms
from
.utils
import
FaceCropper
...
...
bob/learn/pytorch/datasets/data_folder.py
0 → 100644
View file @
06d5b2df
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
@author: Olegs Nikisins
"""
#==============================================================================
# Import what is needed here:
import
torch.utils.data
as
data
import
os
import
random
random
.
seed
(
a
=
7
)
import
PIL
import
numpy
as
np
from
torchvision
import
transforms
import
torch
import
h5py
#==============================================================================
def get_file_names_and_labels(files, data_folder, extension=".hdf5", hldi_type="pad"):
    """
    Get absolute names of the corresponding file objects and their class labels,
    as well as keys defining name of the frame to load the data from.

    **Parameters:**

    ``files`` : [File]
        A list of files objects defined in the High Level Database Interface
        of the particular database. Each object must expose ``path`` and
        ``attack_type`` attributes.

    ``data_folder`` : str
        A directory containing the training data.

    ``extension`` : str
        Extension of the data files. Default: ".hdf5" .

    ``hldi_type`` : str
        Type of the high level database interface. Default: "pad".
        Note: this is the only type supported at the moment; any other
        value yields an empty list.

    **Returns:**

    ``file_names_labels_keys`` : [(str, int, str)]
        A list of tuples, where each tuple contain an absolute filename,
        a corresponding label of the class, and a key defining the name of the
        frame to extract the data from. Files that do not exist on disk are
        skipped.
    """

    file_names_labels_keys = []

    if hldi_type != "pad":
        # Unsupported interface types produce no samples.
        return file_names_labels_keys

    for f in files:

        # Label 0 when the file has no attack type, 1 otherwise.
        label = 0 if f.attack_type is None else 1

        file_name = os.path.join(data_folder, f.path + extension)

        if not os.path.isfile(file_name):
            continue  # missing files are silently skipped

        with h5py.File(file_name, "r") as f_h5py:
            # Top-level keys of the file; each one names a frame.
            file_keys = list(f_h5py.keys())

        # One (filename, label, frame key) tuple per frame. ``extend`` grows
        # the list in place instead of re-allocating it for every file.
        file_names_labels_keys.extend(
            (file_name, label, key) for key in file_keys)

    return file_names_labels_keys
#==============================================================================
class DataFolder(data.Dataset):
    """
    A generic data loader compatible with Bob High Level Database Interfaces
    (HLDI). Only HLDI's of bob.pad.face are currently supported.

    Each sample is one frame (one key) of one HDF5 file; the list of
    (filename, label, key) tuples is built once in the constructor.
    """

    def __init__(self, data_folder,
                 transform=None,
                 extension='.hdf5',
                 bob_hldi_instance=None,
                 hldi_type="pad",
                 groups=['train', 'dev', 'eval'],
                 protocol='grandtest',
                 purposes=['real', 'attack'],
                 allow_missing_files=True,
                 **kwargs):
        """
        **Parameters:**

        ``data_folder`` : str
            A directory containing the training data.

        ``transform`` : callable
            Either a ``torchvision.transforms.Compose`` instance, which is
            applied to a PIL image, or any other callable, which is applied
            to the raw numpy array. Default: None (no transformation).

        ``extension`` : str
            Extension of the data files. Default: ".hdf5".
            Note: this is the only extension supported at the moment.

        ``bob_hldi_instance`` : object
            An instance of the HLDI interface. Only HLDI's of bob.pad.face
            are currently supported. If None, the dataset is empty.

        ``hldi_type`` : str
            String defining the type of the HLDI. Default: "pad".
            Note: this is the only option currently supported.

        ``groups`` : str or [str]
            The groups for which the clients should be returned.
            Usually, groups are one or more elements of ['train', 'dev', 'eval'].
            Default: ['train', 'dev', 'eval'].

        ``protocol`` : str
            The protocol for which the clients should be retrieved.
            Default: 'grandtest'.

        ``purposes`` : str or [str]
            The purposes for which File objects should be retrieved.
            Usually it is either 'real' or 'attack'.
            Default: ['real', 'attack'].

        ``allow_missing_files`` : bool
            The missing files in the ``data_folder`` will not break the
            execution if set to True.
            Default: True.

        ``kwargs`` : dict
            Extra keyword arguments forwarded to
            ``bob_hldi_instance.objects``.
        """

        super(DataFolder, self).__init__()

        self.data_folder = data_folder
        self.transform = transform
        self.extension = extension
        self.bob_hldi_instance = bob_hldi_instance
        self.hldi_type = hldi_type
        # Copy list arguments so that mutating the instance attribute can
        # never alter the shared default lists of the signature.
        self.groups = list(groups) if isinstance(groups, list) else groups
        self.protocol = protocol
        self.purposes = list(purposes) if isinstance(purposes, list) else purposes
        self.allow_missing_files = allow_missing_files

        if bob_hldi_instance is not None:

            files = bob_hldi_instance.objects(groups=self.groups,
                                              protocol=self.protocol,
                                              purposes=self.purposes,
                                              **kwargs)

            file_names_labels_keys = get_file_names_and_labels(
                files=files,
                data_folder=self.data_folder,
                extension=self.extension,
                hldi_type=self.hldi_type)

            if self.allow_missing_files:
                # return only existing files
                file_names_labels_keys = [
                    entry for entry in file_names_labels_keys
                    if os.path.isfile(entry[0])]

        else:

            # TODO - add behaviour similar to image folder
            file_names_labels_keys = []

        self.file_names_labels_keys = file_names_labels_keys

    #==========================================================================
    def __getitem__(self, index):
        """
        Returns an image, possibly transformed, and a target class given index.

        **Parameters:**

        ``index`` : int.
            An index of the sample to return.

        **Returns:**

        ``pil_img`` : Tensor or PIL Image
            If ``self.transform`` is None the output is an instance of the
            PIL.Image.Image class. If it is a ``transforms.Compose`` the
            output is whatever the composed transformation returns for the
            PIL image. For any other callable the output is a torch.Tensor
            built from the transformed array, with a leading dimension of
            size one added.

        ``target`` : int
            Index of the class.
        """

        path, target, key = self.file_names_labels_keys[index]

        with h5py.File(path, "r") as f_h5py:
            img_array = np.array(f_h5py.get(key + '/array'))

        use_pil = (self.transform is None
                   or isinstance(self.transform, transforms.Compose))

        if not use_pil:
            # A custom transformation function works on the raw array.
            img_array_transformed = self.transform(img_array)
            # convert array to Tensor, also return target
            return torch.Tensor(img_array_transformed).unsqueeze(0), target

        if len(img_array.shape) == 3:
            # Color image: move the leading axis to the end, i.e.
            # (C, A, B) -> (A, B, C), which is the layout PIL expects.
            pil_img = PIL.Image.fromarray(np.transpose(img_array, (1, 2, 0)))
        else:
            # Gray-scale image: a 2D array is converted directly.
            pil_img = PIL.Image.fromarray(img_array, 'L')

        if self.transform is not None:
            pil_img = self.transform(pil_img)

        return pil_img, target

    #==========================================================================
    def __len__(self):
        """
        **Returns:**

        ``len`` : int
            The length of the file list.
        """
        return len(self.file_names_labels_keys)
bob/learn/pytorch/scripts/pytorch_train.py
View file @
06d5b2df
...
...
@@ -31,7 +31,7 @@ import argparse
import
importlib
import
os
from
bob.
pad.face.database.pytorch
import
DataFolder
from
bob.
learn.pytorch.datasets
import
DataFolder
import
torch
from
torch.utils.data
import
DataLoader
...
...
@@ -91,10 +91,10 @@ def parse_arguments(cmd_params=None):
parser
.
add_argument
(
"-c"
,
"--config-file"
,
type
=
str
,
help
=
"Relative name of the config file defining "
"the network, training data, and training parameters."
,
default
=
"autoencoder/
autoencoder_config
.py"
)
default
=
"autoencoder/
net1_celeba
.py"
)
parser
.
add_argument
(
"-cg"
,
"--config-group"
,
type
=
str
,
help
=
"Name of the group, where config file is stored."
,
default
=
"bob.
pad.face.config.pytorch
"
)
default
=
"bob.
learn.pytorch.config
"
)
parser
.
add_argument
(
"-p"
,
"--pretrained-model-path"
,
type
=
str
,
help
=
"Absolute name of the file, containing pre-trained Network "
"model, to de used for Network initialization before training."
,
...
...
setup.py
View file @
06d5b2df
...
...
@@ -72,7 +72,7 @@ setup(
'train_cnn.py = bob.learn.pytorch.scripts.train_cnn:main'
,
'train_dcgan.py = bob.learn.pytorch.scripts.train_dcgan:main'
,
'train_conditionalgan.py = bob.learn.pytorch.scripts.train_conditionalgan:main'
,
'pytorch-train-autoencoder-pad.py = bob.
pad.face.script.pytorch
.pytorch_train:main'
,
'pytorch-train-autoencoder-pad.py = bob.
learn.pytorch.scripts
.pytorch_train:main'
,
],
},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment