Commit c076b423 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

finished implementation of LightCNN9

parent f2fd52e4
Pipeline #26303 passed with stage
in 7 minutes and 1 second
#!/usr/bin/env python
# encoding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
......
#!/usr/bin/env python
# encoding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import MaxFeatureMap
from .utils import group
class LightCNN9(nn.Module):
""" The class defining the light CNN with 9 layers
......@@ -43,23 +48,24 @@ class LightCNN9(nn.Module):
group(128, 128, 3, 1, 1),
nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
)
self.fc1 = mfm(8*8*128, 256, type=0)
self.fc1 = MaxFeatureMap(8*8*128, 256, type=0)
self.fc2 = nn.Linear(256, num_classes)
def forward(self, x):
def forward(self, x):
""" Propagate data through the network
Parameters
----------
x: :py:class:`torch.Tensor`
The data to forward through the network
The data to forward through the network. Image of size 1x128x128
Returns
-------
out: :py:class:`torch.Tensor`
class probabilities
x: :py:class:`torch.Tensor`
Output of the penultimate layer (i.e. embedding)
"""
x = self.features(x)
x = x.view(x.size(0), -1)
......
from .CNN8 import CNN8
from .CASIANet import CASIANet
from .LightCNN import LightCNN9
from .DCGAN import DCGAN_generator
from .DCGAN import DCGAN_discriminator
......
#!/usr/bin/env python
# encoding: utf-8
import torch
import torch.nn as nn
......@@ -110,3 +114,52 @@ class MaxFeatureMap(nn.Module):
out = torch.split(x, self.out_channels, 1)
return torch.max(out[0], out[1])
class group(nn.Module):
""" Class implementing ...
Attributes
----------
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
""" Init function
Parameters
----------
in_channels: int
the number of input channels
out_channels: int
the number of output channels
kernel_size: int
The size of the kernel in the convolution
stride: int
The stride in the convolution
padding: int
The padding (default to)
"""
super(group, self).__init__()
self.conv_a = MaxFeatureMap(in_channels, in_channels, 1, 1, 0)
self.conv = MaxFeatureMap(in_channels, out_channels, kernel_size, stride, padding)
def forward(self, x):
""" Forward function
Propagates data through the Max Feature Map
Parameters
----------
x: :py:class:`torch.Tensor`
The data to forward through the MFM
Returns
-------
py:class:`torch.Tensor`
"""
x = self.conv_a(x)
x = self.conv(x)
return x
......@@ -33,6 +33,15 @@ def test_architectures():
assert output.shape == torch.Size([1, 20])
assert emdedding.shape == torch.Size([1, 512])
# LightCNN9
a = numpy.random.rand(1, 1, 128, 128).astype("float32")
t = torch.from_numpy(a)
from ..architectures import LightCNN9
net = LightCNN9()
output, emdedding = net.forward(t)
assert output.shape == torch.Size([1, 79077])
assert emdedding.shape == torch.Size([1, 256])
# DCGAN
d = numpy.random.rand(1, 3, 64, 64).astype("float32")
t = torch.from_numpy(d)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment