Commit 8881e90f authored by Olegs NIKISINS

Sigmoid is now optional in the TwoLayerMLP

parent bc460d80
1 merge request: !14 MLP class and config to train it
@@ -14,10 +14,10 @@ import torch.nn.functional as F
#==============================================================================
# Define the network:
class TwoLayerMLP(nn.Module):
"""
A simple two-layer MLP for binary classification.
A simple two-layer MLP for binary classification. The output activation
function is sigmoid.
Attributes
----------
@@ -26,9 +26,13 @@ class TwoLayerMLP(nn.Module):
    n_hidden_relu : int
        Number of ReLU units in the hidden layer of the MLP.
    apply_sigmoid : bool
        If set to ``True`` a sigmoid is applied to the output of the last
        FC layer. If ``False`` the sigmoid is not applied and the raw
        output is returned.
    """
    def __init__(self, in_features, n_hidden_relu):
    def __init__(self, in_features, n_hidden_relu, apply_sigmoid = True):
        super(TwoLayerMLP, self).__init__()
        """
        Init method.
@@ -38,10 +42,13 @@
        self.n_hidden_relu = n_hidden_relu
        self.apply_sigmoid = apply_sigmoid

        self.fc1 = nn.Linear(in_features = self.in_features, out_features = self.n_hidden_relu, bias=True)
        self.fc2 = nn.Linear(in_features = self.n_hidden_relu, out_features = 1, bias=True)
    def forward(self, x):
        """
        The forward method.
@@ -57,6 +64,11 @@ class TwoLayerMLP(nn.Module):
        # second fully connected layer, optionally activated by sigmoid:
        x = self.fc2(x)

        if not self.apply_sigmoid:
            return x

        x = F.sigmoid(x)

        return x
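
For context, a minimal self-contained sketch of the class after this change, together with a typical way to use the new ``apply_sigmoid`` flag. The class body is assembled from the hunks above (the ReLU activation of ``fc1`` in ``forward`` is not visible in the diff and is inferred from ``n_hidden_relu``); the feature sizes, the batch, and the pairing with ``nn.BCEWithLogitsLoss`` are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoLayerMLP(nn.Module):
    # Condensed from the diff above; docstrings omitted. Lines hidden in the
    # diff (e.g. the ReLU on fc1) are filled in with the obvious definitions.
    def __init__(self, in_features, n_hidden_relu, apply_sigmoid=True):
        super(TwoLayerMLP, self).__init__()
        self.in_features = in_features
        self.n_hidden_relu = n_hidden_relu
        self.apply_sigmoid = apply_sigmoid
        self.fc1 = nn.Linear(in_features=in_features, out_features=n_hidden_relu, bias=True)
        self.fc2 = nn.Linear(in_features=n_hidden_relu, out_features=1, bias=True)

    def forward(self, x):
        x = F.relu(self.fc1(x))      # hidden layer with ReLU units (inferred, not shown in the diff)
        x = self.fc2(x)              # second FC layer: a single output score
        if not self.apply_sigmoid:   # new behaviour: return the raw score
            return x
        x = F.sigmoid(x)             # default behaviour unchanged
        return x

# Illustrative use of the new flag: train on raw scores with BCEWithLogitsLoss,
# which applies the sigmoid internally in a numerically stable way.
model = TwoLayerMLP(in_features=128, n_hidden_relu=64, apply_sigmoid=False)
criterion = nn.BCEWithLogitsLoss()
features = torch.randn(32, 128)                  # hypothetical batch of 32 samples
labels = torch.randint(0, 2, (32, 1)).float()    # hypothetical binary labels
loss = criterion(model(features), labels)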