Skip to content
Snippets Groups Projects
Commit b315607e authored by M. François's avatar M. François
Browse files

consistent init

parent 5202a60f
No related branches found
No related tags found
No related merge requests found
...@@ -53,13 +53,13 @@ class NeuralFilter(torch.nn.Module): ...@@ -53,13 +53,13 @@ class NeuralFilter(torch.nn.Module):
parts = self.hidden_size * 2 parts = self.hidden_size * 2
ranges = np.arange(1, parts, 2) ranges = np.arange(1, parts, 2)
init_modulus = ranges * (max_modulus - min_modulus) / parts + min_modulus init = ranges * (max_modulus - min_modulus) / parts + min_modulus
init = asig(init_modulus)
if not isinstance(init, np.ndarray): if not isinstance(init, np.ndarray):
init = np.array(init, ndmin=1) init = np.array(init, ndmin=1)
ten_init = torch.from_numpy(init) init_modulus = asig(init)
ten_init = torch.from_numpy(init_modulus)
self.bias_forget.data.copy_(ten_init) self.bias_forget.data.copy_(ten_init)
def __repr__(self): def __repr__(self):
......
...@@ -54,27 +54,28 @@ class NeuralFilter2CC(torch.nn.Module): ...@@ -54,27 +54,28 @@ class NeuralFilter2CC(torch.nn.Module):
min_angle=MIN_ANGLE, max_angle=MAX_ANGLE, modulus=INIT_MODULUS): min_angle=MIN_ANGLE, max_angle=MAX_ANGLE, modulus=INIT_MODULUS):
if init_modulus is None: if init_modulus is None:
init_modulus = asig(modulus) init_modulus = modulus
if not isinstance(init_modulus, np.ndarray): if not isinstance(init_modulus, np.ndarray):
init_modulus = np.array(init_modulus, ndmin=1) init_modulus = np.array(init_modulus, ndmin=1)
ten_init = torch.from_numpy(init_modulus) init_mod = asig(init_modulus)
ten_init = torch.from_numpy(init_mod)
self.bias_modulus.data.copy_(ten_init) self.bias_modulus.data.copy_(ten_init)
if init_theta is None: if init_theta is None:
parts = self.hidden_size * 2 parts = self.hidden_size * 2
ranges = np.arange(1, parts, 2) ranges = np.arange(1, parts, 2)
init_angle = ranges * (max_angle - min_angle) / parts + min_angle init_theta = ranges * (max_angle - min_angle) / parts + min_angle
cosangle = np.cos(init_angle)
init_theta = atanh(cosangle)
if not isinstance(init_theta, np.ndarray): if not isinstance(init_theta, np.ndarray):
init_theta = np.array(init_theta, ndmin=1) init_theta = np.array(init_theta, ndmin=1)
ten_init = torch.from_numpy(init_theta) cosangle = np.cos(init_theta)
init_angle = atanh(cosangle)
ten_init = torch.from_numpy(init_angle)
self.bias_theta.data.copy_(ten_init) self.bias_theta.data.copy_(ten_init)
def __repr__(self): def __repr__(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment