Commit 68ae48c1 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[architectures] moved the function to init the weights into utils

parent 7c489023
......@@ -3,26 +3,7 @@
import torch
import torch.nn as nn
def weights_init(m):
""" Initialize the weights
Initialize the weights in the different layers of
the network.
Parameters
----------
m : :py:class:`torch.nn.Conv2d`
The layer to initialize
"""
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
from .utils import weights_init
class DCGAN_generator(nn.Module):
......
......@@ -30,3 +30,24 @@ def make_conv_layers(cfg, input_c = 3):
in_channels = v
return nn.Sequential(*layers)
def weights_init(m):
""" Initialize the weights
Initialize the weights in the different layers of
the network.
Parameters
----------
m : :py:class:`torch.nn.Conv2d`
The layer to initialize
"""
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment