Skip to content
Snippets Groups Projects
Commit 9944b123 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[utils] added docstrings

parent e5305554
Branches
Tags v4.0.2
No related merge requests found
...@@ -3,14 +3,30 @@ import torch.nn as nn ...@@ -3,14 +3,30 @@ import torch.nn as nn
def make_conv_layers(cfg, input_c = 3): def make_conv_layers(cfg, input_c = 3):
layers = [] """ builds the convolution / max pool layers
in_channels = input_c
for v in cfg: The network architecture is provided as a list, containing the
if v == 'M': number of feature maps, or a 'M' for a MaxPooling layer.
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else: Example for Casia-Net:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) [32, 64, 'M', 64, 128, 'M', 96, 192, 'M', 128, 256, 'M', 160, 320]
layers += [conv2d, nn.ReLU()]
in_channels = v Parameters
return nn.Sequential(*layers) ----------
cfg: list
Configuration for the network (see above)
input_c: int
The number of channels in the input (1 -> gray, 3 -> rgb)
"""
layers = []
in_channels = input_c
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
layers += [conv2d, nn.ReLU()]
in_channels = v
return nn.Sequential(*layers)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment