Skip to content
Snippets Groups Projects
Commit c076b423 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

finished implementation of LightCNN9

parent f2fd52e4
No related branches found
No related tags found
1 merge request!9Light cnn
Pipeline #26303 passed
#!/usr/bin/env python
# encoding: utf-8
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
......
#!/usr/bin/env python #!/usr/bin/env python
# encoding: utf-8 # encoding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import MaxFeatureMap from .utils import MaxFeatureMap
from .utils import group
class LightCNN9(nn.Module): class LightCNN9(nn.Module):
""" The class defining the light CNN with 9 layers """ The class defining the light CNN with 9 layers
...@@ -43,23 +48,24 @@ class LightCNN9(nn.Module): ...@@ -43,23 +48,24 @@ class LightCNN9(nn.Module):
group(128, 128, 3, 1, 1), group(128, 128, 3, 1, 1),
nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True), nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
) )
self.fc1 = mfm(8*8*128, 256, type=0) self.fc1 = MaxFeatureMap(8*8*128, 256, type=0)
self.fc2 = nn.Linear(256, num_classes) self.fc2 = nn.Linear(256, num_classes)
def forward(self, x):
def forward(self, x):
""" Propagate data through the network """ Propagate data through the network
Parameters Parameters
---------- ----------
x: :py:class:`torch.Tensor` x: :py:class:`torch.Tensor`
The data to forward through the network The data to forward through the network. Image of size 1x128x128
Returns Returns
------- -------
out: :py:class:`torch.Tensor` out: :py:class:`torch.Tensor`
class probabilities
x: :py:class:`torch.Tensor` x: :py:class:`torch.Tensor`
Output of the penultimate layer (i.e. embedding)
""" """
x = self.features(x) x = self.features(x)
x = x.view(x.size(0), -1) x = x.view(x.size(0), -1)
......
from .CNN8 import CNN8 from .CNN8 import CNN8
from .CASIANet import CASIANet from .CASIANet import CASIANet
from .LightCNN import LightCNN9
from .DCGAN import DCGAN_generator from .DCGAN import DCGAN_generator
from .DCGAN import DCGAN_discriminator from .DCGAN import DCGAN_discriminator
......
#!/usr/bin/env python
# encoding: utf-8
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -110,3 +114,52 @@ class MaxFeatureMap(nn.Module): ...@@ -110,3 +114,52 @@ class MaxFeatureMap(nn.Module):
out = torch.split(x, self.out_channels, 1) out = torch.split(x, self.out_channels, 1)
return torch.max(out[0], out[1]) return torch.max(out[0], out[1])
class group(nn.Module):
""" Class implementing ...
Attributes
----------
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
""" Init function
Parameters
----------
in_channels: int
the number of input channels
out_channels: int
the number of output channels
kernel_size: int
The size of the kernel in the convolution
stride: int
The stride in the convolution
padding: int
The padding (default to)
"""
super(group, self).__init__()
self.conv_a = MaxFeatureMap(in_channels, in_channels, 1, 1, 0)
self.conv = MaxFeatureMap(in_channels, out_channels, kernel_size, stride, padding)
def forward(self, x):
""" Forward function
Propagates data through the Max Feature Map
Parameters
----------
x: :py:class:`torch.Tensor`
The data to forward through the MFM
Returns
-------
py:class:`torch.Tensor`
"""
x = self.conv_a(x)
x = self.conv(x)
return x
...@@ -33,6 +33,15 @@ def test_architectures(): ...@@ -33,6 +33,15 @@ def test_architectures():
assert output.shape == torch.Size([1, 20]) assert output.shape == torch.Size([1, 20])
assert emdedding.shape == torch.Size([1, 512]) assert emdedding.shape == torch.Size([1, 512])
# LightCNN9
a = numpy.random.rand(1, 1, 128, 128).astype("float32")
t = torch.from_numpy(a)
from ..architectures import LightCNN9
net = LightCNN9()
output, emdedding = net.forward(t)
assert output.shape == torch.Size([1, 79077])
assert emdedding.shape == torch.Size([1, 256])
# DCGAN # DCGAN
d = numpy.random.rand(1, 3, 64, 64).astype("float32") d = numpy.random.rand(1, 3, 64, 64).astype("float32")
t = torch.from_numpy(d) t = torch.from_numpy(d)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment