Skip to content
Snippets Groups Projects
Commit c4ad480f authored by M. François's avatar M. François
Browse files

modulus output in forward

parent cf0fc7d9
Branches
Tags
No related merge requests found
...@@ -125,7 +125,7 @@ class NeuralFilter(torch.nn.Module): ...@@ -125,7 +125,7 @@ class NeuralFilter(torch.nn.Module):
if is_packed: if is_packed:
output = PackedSequence(output, batch_sizes) output = PackedSequence(output, batch_sizes)
return output, nexth return output, nexth, a
@property @property
def gradients(self): def gradients(self):
...@@ -138,7 +138,7 @@ class NeuralFilter(torch.nn.Module): ...@@ -138,7 +138,7 @@ class NeuralFilter(torch.nn.Module):
@property @property
def denominator(self): def denominator(self):
forgetgate = F.sigmoid(self.bias_forget).data.numpy() forgetgate = F.sigmoid(self.bias_forget).detach().cpu().numpy()
forgetgate = forgetgate.reshape((forgetgate.size, 1)) forgetgate = forgetgate.reshape((forgetgate.size, 1))
one = np.ones(forgetgate.shape) one = np.ones(forgetgate.shape)
denom = np.concatenate((one, -forgetgate), axis=1) denom = np.concatenate((one, -forgetgate), axis=1)
......
...@@ -151,7 +151,7 @@ class NeuralFilter2CC(torch.nn.Module): ...@@ -151,7 +151,7 @@ class NeuralFilter2CC(torch.nn.Module):
if is_packed: if is_packed:
output = PackedSequence(output, batch_sizes) output = PackedSequence(output, batch_sizes)
return output, nexth return output, nexth, modulus
def print_param(self): def print_param(self):
modulus = F.sigmoid(self.bias_modulus) modulus = F.sigmoid(self.bias_modulus)
...@@ -166,8 +166,8 @@ class NeuralFilter2CC(torch.nn.Module): ...@@ -166,8 +166,8 @@ class NeuralFilter2CC(torch.nn.Module):
cosangle = F.tanh(self.bias_theta) cosangle = F.tanh(self.bias_theta)
p1 = -2 * cosangle * modulus p1 = -2 * cosangle * modulus
p2 = modulus.pow(2) p2 = modulus.pow(2)
p1 = p1.data.numpy() p1 = p1.detach().cpu().numpy()
p2 = p2.data.numpy() p2 = p2.detach().cpu().numpy()
p1 = p1.reshape(p1.size, 1) p1 = p1.reshape(p1.size, 1)
p2 = p2.reshape(p2.size, 1) p2 = p2.reshape(p2.size, 1)
one = np.ones(p1.shape) one = np.ones(p1.shape)
......
...@@ -53,10 +53,10 @@ class NeuralFilter2CD(torch.nn.Module): ...@@ -53,10 +53,10 @@ class NeuralFilter2CD(torch.nn.Module):
return s.format(name=self.__class__.__name__, **self.__dict__) return s.format(name=self.__class__.__name__, **self.__dict__)
def forward(self, input_var, hx=(None, None)): def forward(self, input_var, hx=(None, None)):
inter, inter_hidden = self.cell(input_var, hx[0]) inter, inter_hidden, modulus = self.cell(input_var, hx[0])
output, hidden = self.cell(inter, hx[1]) output, hidden, modulus = self.cell(inter, hx[1])
return output, (inter_hidden, hidden) return output, (inter_hidden, hidden), modulus
@property @property
def denominator(self): def denominator(self):
......
...@@ -65,10 +65,10 @@ class NeuralFilter2R(torch.nn.Module): ...@@ -65,10 +65,10 @@ class NeuralFilter2R(torch.nn.Module):
return s.format(name=self.__class__.__name__, **self.__dict__) return s.format(name=self.__class__.__name__, **self.__dict__)
def forward(self, input_var, hx=(None, None)): def forward(self, input_var, hx=(None, None)):
interm, interm_hidden = self.first_cell(input_var, hx[0]) interm, interm_hidden, first_modulus = self.first_cell(input_var, hx[0])
output, hidden = self.second_cell(interm, hx[1]) output, hidden, second_modulus = self.second_cell(interm, hx[1])
return output, (interm_hidden, hidden) return output, (interm_hidden, hidden), (first_modulus, second_modulus)
@property @property
def denominator(self): def denominator(self):
......
...@@ -2,7 +2,7 @@ from setuptools import setup, find_packages ...@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup( setup(
name='neural-filters', name='neural-filters',
version='1.0', version='1.1',
description='Linear filters for neural networks in pyTorch', description='Linear filters for neural networks in pyTorch',
author='Idiap research institute - Francois Marelli', author='Idiap research institute - Francois Marelli',
author_email='francois.marelli@idiap.ch', author_email='francois.marelli@idiap.ch',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment