Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • 19-dataset-handling
  • bob8_test
  • bob_9_datafolder
  • drgan
  • generic_image_extractor
  • light
  • master
  • no-pad-base
  • swin_transformer
  • test_conda
  • tfrecord
  • v0.0.1
  • v0.0.2
  • v0.0.3
  • v0.0.4
  • v0.0.5
  • v0.1.0
  • v0.1.1
  • v0.1.2
  • v0.1.3b0
  • v0.2.0
21 results

Target

Select target project
  • bob/bob.learn.pytorch
1 result
Select Git revision
  • 19-dataset-handling
  • bob8_test
  • bob_9_datafolder
  • drgan
  • generic_image_extractor
  • light
  • master
  • no-pad-base
  • swin_transformer
  • test_conda
  • tfrecord
  • v0.0.1
  • v0.0.2
  • v0.0.3
  • v0.0.4
  • v0.0.5
  • v0.1.0
  • v0.1.1
  • v0.1.2
  • v0.1.3b0
  • v0.2.0
21 results
Show changes
Commits on Source (6)
...@@ -15,3 +15,4 @@ log* ...@@ -15,3 +15,4 @@ log*
results* results*
build/ build/
record.txt record.txt
*lightcnn*
...@@ -57,10 +57,11 @@ class ConvAutoencoder(nn.Module): ...@@ -57,10 +57,11 @@ class ConvAutoencoder(nn.Module):
The forward method. The forward method.
""" """
x = self.encoder(x) x = self.encoder(x)
x = self.decoder(x)
if self.return_latent_embedding: if self.return_latent_embedding:
return self.encoder(x) return x
x = self.decoder(x)
return x return x
...@@ -7,12 +7,14 @@ import torch.nn.functional as F ...@@ -7,12 +7,14 @@ import torch.nn.functional as F
from .utils import MaxFeatureMap from .utils import MaxFeatureMap
from .utils import group from .utils import group
from .utils import resblock
class LightCNN9(nn.Module): class LightCNN9(nn.Module):
""" The class defining the light CNN with 9 layers """ The class defining the light CNN with 9 layers
This class implements the CNN described in: This class implements the CNN described in:
"Learning Face Representation From Scratch", D. Yi, Z. Lei, S. Liao and S.z. Li, 2014 "A light CNN for deep face representation with noisy labels", Wu, Xiang and He, Ran and Sun, Zhenan and Tan, Tieniu,
IEEE Transactions on Information Forensics and Security, vol 13, issue 11, 2018
Attributes Attributes
---------- ----------
...@@ -74,3 +76,175 @@ class LightCNN9(nn.Module): ...@@ -74,3 +76,175 @@ class LightCNN9(nn.Module):
out = self.fc2(x) out = self.fc2(x)
return out, x return out, x
class LightCNN29(nn.Module):
""" The class defining the light CNN with 29 layers
This class implements the CNN described in:
"A light CNN for deep face representation with noisy labels", Wu, Xiang and He, Ran and Sun, Zhenan and Tan, Tieniu,
IEEE Transactions on Information Forensics and Security, vol 13, issue 11, 2018
Attributes
----------
"""
def __init__(self, block=resblock, layers=[1, 2, 3, 4], num_classes=79077):
""" Init function
Parameters
----------
num_classes: int
The number of classes.
"""
super(LightCNN29, self).__init__()
self.conv1 = MaxFeatureMap(1, 48, 5, 1, 2)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
self.block1 = self._make_layer(block, layers[0], 48, 48)
self.group1 = group(48, 96, 3, 1, 1)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
self.block2 = self._make_layer(block, layers[1], 96, 96)
self.group2 = group(96, 192, 3, 1, 1)
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
self.block3 = self._make_layer(block, layers[2], 192, 192)
self.group3 = group(192, 128, 3, 1, 1)
self.block4 = self._make_layer(block, layers[3], 128, 128)
self.group4 = group(128, 128, 3, 1, 1)
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
self.fc = MaxFeatureMap(8*8*128, 256, type=0)
self.fc2 = nn.Linear(256, num_classes)
def _make_layer(self, block, num_blocks, in_channels, out_channels):
"""
Parameters
----------
"""
layers = []
for i in range(0, num_blocks):
layers.append(block(in_channels, out_channels))
return nn.Sequential(*layers)
def forward(self, x):
""" Propagate data through the network
Parameters
----------
x: :py:class:`torch.Tensor`
The data to forward through the network. Image of size 1x128x128
Returns
-------
out: :py:class:`torch.Tensor`
class probabilities
x: :py:class:`torch.Tensor`
Output of the penultimate layer (i.e. embedding)
"""
x = self.conv1(x)
x = self.pool1(x)
x = self.block1(x)
x = self.group1(x)
x = self.pool2(x)
x = self.block2(x)
x = self.group2(x)
x = self.pool3(x)
x = self.block3(x)
x = self.group3(x)
x = self.block4(x)
x = self.group4(x)
x = self.pool4(x)
x = x.view(x.size(0), -1)
fc = self.fc(x)
fc = F.dropout(fc, training=self.training)
out = self.fc2(fc)
return out, fc
class LightCNN29v2(nn.Module):
""" The class defining the light CNN with 29 layers (version 2)
This class implements the CNN described in:
"A light CNN for deep face representation with noisy labels", Wu, Xiang and He, Ran and Sun, Zhenan and Tan, Tieniu,
IEEE Transactions on Information Forensics and Security, vol 13, issue 11, 2018
Attributes
----------
"""
def __init__(self, block=resblock, layers=[1, 2, 3, 4], num_classes=79077):
""" Init function
Parameters
----------
num_classes: int
The number of classes.
"""
super(LightCNN29v2, self).__init__()
self.conv1 = MaxFeatureMap(1, 48, 5, 1, 2)
self.block1 = self._make_layer(block, layers[0], 48, 48)
self.group1 = group(48, 96, 3, 1, 1)
self.block2 = self._make_layer(block, layers[1], 96, 96)
self.group2 = group(96, 192, 3, 1, 1)
self.block3 = self._make_layer(block, layers[2], 192, 192)
self.group3 = group(192, 128, 3, 1, 1)
self.block4 = self._make_layer(block, layers[3], 128, 128)
self.group4 = group(128, 128, 3, 1, 1)
self.fc = nn.Linear(8*8*128, 256)
self.fc2 = nn.Linear(256, num_classes, bias=False)
def _make_layer(self, block, num_blocks, in_channels, out_channels):
"""
Parameters
----------
"""
layers = []
for i in range(0, num_blocks):
layers.append(block(in_channels, out_channels))
return nn.Sequential(*layers)
def forward(self, x):
""" Propagate data through the network
Parameters
----------
x: :py:class:`torch.Tensor`
The data to forward through the network. Image of size 1x128x128
Returns
-------
out: :py:class:`torch.Tensor`
class probabilities
x: :py:class:`torch.Tensor`
Output of the penultimate layer (i.e. embedding)
"""
x = self.conv1(x)
x = F.max_pool2d(x, 2) + F.avg_pool2d(x, 2)
x = self.block1(x)
x = self.group1(x)
x = F.max_pool2d(x, 2) + F.avg_pool2d(x, 2)
x = self.block2(x)
x = self.group2(x)
x = F.max_pool2d(x, 2) + F.avg_pool2d(x, 2)
x = self.block3(x)
x = self.group3(x)
x = self.block4(x)
x = self.group4(x)
x = F.max_pool2d(x, 2) + F.avg_pool2d(x, 2)
x = x.view(x.size(0), -1)
fc = self.fc(x)
x = F.dropout(fc, training=self.training)
out = self.fc2(x)
return out, fc
from .CNN8 import CNN8 from .CNN8 import CNN8
from .CASIANet import CASIANet from .CASIANet import CASIANet
from .LightCNN import LightCNN9 from .LightCNN import LightCNN9
from .LightCNN import LightCNN29
from .LightCNN import LightCNN29v2
from .DCGAN import DCGAN_generator from .DCGAN import DCGAN_generator
from .DCGAN import DCGAN_discriminator from .DCGAN import DCGAN_discriminator
......
...@@ -163,3 +163,19 @@ class group(nn.Module): ...@@ -163,3 +163,19 @@ class group(nn.Module):
x = self.conv(x) x = self.conv(x)
return x return x
class resblock(nn.Module):
""" Class implementing ...
"""
def __init__(self, in_channels, out_channels):
super(resblock, self).__init__()
self.conv1 = MaxFeatureMap(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv2 = MaxFeatureMap(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
res = x
out = self.conv1(x)
out = self.conv2(out)
out = out + res
return out
...@@ -41,6 +41,24 @@ def test_architectures(): ...@@ -41,6 +41,24 @@ def test_architectures():
output, emdedding = net.forward(t) output, emdedding = net.forward(t)
assert output.shape == torch.Size([1, 79077]) assert output.shape == torch.Size([1, 79077])
assert emdedding.shape == torch.Size([1, 256]) assert emdedding.shape == torch.Size([1, 256])
# LightCNN29
a = numpy.random.rand(1, 1, 128, 128).astype("float32")
t = torch.from_numpy(a)
from ..architectures import LightCNN29
net = LightCNN29()
output, emdedding = net.forward(t)
assert output.shape == torch.Size([1, 79077])
assert emdedding.shape == torch.Size([1, 256])
# LightCNN29v2
a = numpy.random.rand(1, 1, 128, 128).astype("float32")
t = torch.from_numpy(a)
from ..architectures import LightCNN29v2
net = LightCNN29v2()
output, emdedding = net.forward(t)
assert output.shape == torch.Size([1, 79077])
assert emdedding.shape == torch.Size([1, 256])
# DCGAN # DCGAN
d = numpy.random.rand(1, 3, 64, 64).astype("float32") d = numpy.random.rand(1, 3, 64, 64).astype("float32")
......
...@@ -70,9 +70,11 @@ class CNNTrainer(object): ...@@ -70,9 +70,11 @@ class CNNTrainer(object):
""" """
try: try:
cp = torch.load(model_filename) cp = torch.load(model_filename)
logger.info("model {} loaded".format(model_filename))
except RuntimeError: except RuntimeError:
# pre-trained model was probably saved using nn.DataParallel ... # pre-trained model was probably saved using nn.DataParallel ...
cp = torch.load(model_filename, map_location='cpu') cp = torch.load(model_filename, map_location='cpu')
logger.info("model {} loaded on CPU".format(model_filename))
if 'state_dict' in cp: if 'state_dict' in cp:
from collections import OrderedDict from collections import OrderedDict
...@@ -82,11 +84,14 @@ class CNNTrainer(object): ...@@ -82,11 +84,14 @@ class CNNTrainer(object):
new_state_dict[name] = v new_state_dict[name] = v
cp['state_dict'] = new_state_dict cp['state_dict'] = new_state_dict
logger.info("state_dict modified")
########################################################################################################### ###########################################################################################################
### for each defined architecture, get the output size in pre-trained model, and change it if necessary ### ### for each defined architecture, get the output size in pre-trained model, and change it if necessary ###
# LightCNN9 # LightCNN
if isinstance(self.network, bob.learn.pytorch.architectures.LightCNN.LightCNN9): if isinstance(self.network, bob.learn.pytorch.architectures.LightCNN.LightCNN9) \
or isinstance(self.network, bob.learn.pytorch.architectures.LightCNN.LightCNN29) \
or isinstance(self.network, bob.learn.pytorch.architectures.LightCNN.LightCNN29v2):
last_layer_weight = 'fc2.weight' last_layer_weight = 'fc2.weight'
last_layer_bias = 'fc2.bias' last_layer_bias = 'fc2.bias'
...@@ -99,9 +104,10 @@ class CNNTrainer(object): ...@@ -99,9 +104,10 @@ class CNNTrainer(object):
var = 1.0 / (cp['state_dict'][last_layer_weight].shape[0]) var = 1.0 / (cp['state_dict'][last_layer_weight].shape[0])
np_weights = numpy.random.normal(loc=0.0, scale=var, size=((self.num_classes+1), cp['state_dict'][last_layer_weight].shape[1])) np_weights = numpy.random.normal(loc=0.0, scale=var, size=((self.num_classes+1), cp['state_dict'][last_layer_weight].shape[1]))
cp['state_dict'][last_layer_weight] = torch.from_numpy(np_weights) cp['state_dict'][last_layer_weight] = torch.from_numpy(np_weights)
cp['state_dict'][last_layer_bias] = torch.zeros(((self.num_classes+1),)) if not (isinstance(self.network, bob.learn.pytorch.architectures.LightCNN.LightCNN29v2)):
#self.network.load_state_dict(cp['state_dict'], strict=False) cp['state_dict'][last_layer_bias] = torch.zeros(((self.num_classes+1),))
self.network.load_state_dict(cp['state_dict'], strict=True) self.network.load_state_dict(cp['state_dict'], strict=True)
logger.info("state_dict loaded for {} with {} classes".format(type(self.network), self.num_classes))
# CNN8 # CNN8
if isinstance(self.network, bob.learn.pytorch.architectures.CNN8): if isinstance(self.network, bob.learn.pytorch.architectures.CNN8):
...@@ -116,6 +122,7 @@ class CNNTrainer(object): ...@@ -116,6 +122,7 @@ class CNNTrainer(object):
cp['state_dict']['classifier.bias'] = torch.zeros(((self.num_classes+1),)) cp['state_dict']['classifier.bias'] = torch.zeros(((self.num_classes+1),))
#self.network.load_state_dict(cp['state_dict'], strict=False) #self.network.load_state_dict(cp['state_dict'], strict=False)
self.network.load_state_dict(cp['state_dict'], strict=True) self.network.load_state_dict(cp['state_dict'], strict=True)
logger.info("state_dict loaded for {} with {} classes".format(type(self.network), self.num_classes))
# CASIANet # CASIANet
if isinstance(self.network, bob.learn.pytorch.architectures.CASIANet): if isinstance(self.network, bob.learn.pytorch.architectures.CASIANet):
...@@ -130,6 +137,7 @@ class CNNTrainer(object): ...@@ -130,6 +137,7 @@ class CNNTrainer(object):
cp['state_dict']['classifier.bias'] = torch.zeros(((self.num_classes+1),)) cp['state_dict']['classifier.bias'] = torch.zeros(((self.num_classes+1),))
#self.network.load_state_dict(cp['state_dict'], strict=False) #self.network.load_state_dict(cp['state_dict'], strict=False)
self.network.load_state_dict(cp['state_dict'], strict=True) self.network.load_state_dict(cp['state_dict'], strict=True)
logger.info("state_dict loaded for {} with {} classes".format(type(self.network), self.num_classes))
########################################################################################################### ###########################################################################################################
......