Commit 884dec6d authored by Anjith GEORGE, committed by Anjith GEORGE

Small modifications in MCCNN and FASNet trainer

parent 05ed6c03
1 merge request: !22 Cross validation
@@ -43,7 +43,7 @@ class MCCNN(nn.Module):
         The path to download the pretrained LightCNN model from.
     """
-    def __init__(self, block=resblock, layers=[1, 2, 3, 4], num_channels=4, verbosity_level=2):
+    def __init__(self, block=resblock, layers=[1, 2, 3, 4], num_channels=4, verbosity_level=2, use_sigmoid=True):
         """ Init function
         Parameters
@@ -51,6 +51,9 @@ class MCCNN(nn.Module):
         num_channels: int
             The number of channels present in the input
+        use_sigmoid: bool
+            Whether to use sigmoid in the eval phase. If set to `False`, sigmoid is
+            not applied in the eval phase; the training phase is not affected.
         verbosity_level: int
             Verbosity level.
@@ -58,6 +61,7 @@ class MCCNN(nn.Module):
         super(MCCNN, self).__init__()
         self.num_channels=num_channels
+        self.use_sigmoid=use_sigmoid
         self.lcnn_layers=['conv1','block1','group1','block2', 'group2','block3','group3','block4','group4','fc']
@@ -211,9 +215,9 @@ class MCCNN(nn.Module):
         output = self.linear2fc(output)
-        if self.training:
-            output=nn.Sigmoid()(output)
+        #if self.training:
+        if self.training or self.use_sigmoid:
+            output=nn.Sigmoid()(output)
         return output
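
For reference, below is a minimal, self-contained sketch of what the new use_sigmoid flag changes at inference time. SigmoidHead, its single linear layer, and the input sizes are hypothetical stand-ins for illustration, not the actual MCCNN architecture; only the output-head condition mirrors the diff above.

# Minimal sketch (plain PyTorch, not the real MCCNN) of the output-head logic above:
# sigmoid is always applied during training, but in eval mode it can be skipped by
# constructing the model with use_sigmoid=False, so downstream code gets raw logits.
import torch
import torch.nn as nn

class SigmoidHead(nn.Module):  # hypothetical toy module, for illustration only
    def __init__(self, in_features=10, use_sigmoid=True):
        super(SigmoidHead, self).__init__()
        self.use_sigmoid = use_sigmoid
        self.linear2fc = nn.Linear(in_features, 1)

    def forward(self, x):
        output = self.linear2fc(x)
        #if self.training:
        if self.training or self.use_sigmoid:
            output = nn.Sigmoid()(output)
        return output

x = torch.randn(4, 10)
head = SigmoidHead(use_sigmoid=False)
head.eval()
print(head(x))   # raw logits in eval mode when use_sigmoid=False
head.train()
print(head(x))   # sigmoid is still applied during training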
@@ -178,7 +178,7 @@ class FASNetTrainer(object):
         # setup optimizer
-        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.network.parameters()),lr = learning_rate )
+        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.network.parameters()),lr = learning_rate, weight_decay=0.000001)
         self.network.train(True)
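
And a minimal sketch of the optimizer change in the FASNet trainer: Adam is now created with a small weight_decay (1e-6, i.e. L2 regularization) over the parameters that are still trainable. The network, the frozen layer, and the learning_rate value below are hypothetical placeholders; only the optimizer call mirrors the committed line.

# Sketch assuming a plain torch.nn model named `network` (not the actual FASNet trainer).
import torch.nn as nn
import torch.optim as optim

network = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 1))  # stand-in model
for p in network[0].parameters():   # e.g. freeze an early (pretrained) layer
    p.requires_grad = False

learning_rate = 0.0001  # placeholder value
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, network.parameters()),  # only trainable parameters
    lr=learning_rate,
    weight_decay=0.000001,  # same value as in the commit: 1e-6 L2 penalty
)
print(optimizer)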