Commit b315607e authored by M. François's avatar M. François

consistent init

parent 5202a60f
...@@ -53,13 +53,13 @@ class NeuralFilter(torch.nn.Module): ...@@ -53,13 +53,13 @@ class NeuralFilter(torch.nn.Module):
parts = self.hidden_size * 2 parts = self.hidden_size * 2
ranges = np.arange(1, parts, 2) ranges = np.arange(1, parts, 2)
init_modulus = ranges * (max_modulus - min_modulus) / parts + min_modulus init = ranges * (max_modulus - min_modulus) / parts + min_modulus
init = asig(init_modulus)
if not isinstance(init, np.ndarray): if not isinstance(init, np.ndarray):
init = np.array(init, ndmin=1) init = np.array(init, ndmin=1)
ten_init = torch.from_numpy(init) init_modulus = asig(init)
ten_init = torch.from_numpy(init_modulus)
self.bias_forget.data.copy_(ten_init) self.bias_forget.data.copy_(ten_init)
def __repr__(self): def __repr__(self):
......
...@@ -54,27 +54,28 @@ class NeuralFilter2CC(torch.nn.Module): ...@@ -54,27 +54,28 @@ class NeuralFilter2CC(torch.nn.Module):
min_angle=MIN_ANGLE, max_angle=MAX_ANGLE, modulus=INIT_MODULUS): min_angle=MIN_ANGLE, max_angle=MAX_ANGLE, modulus=INIT_MODULUS):
if init_modulus is None: if init_modulus is None:
init_modulus = asig(modulus) init_modulus = modulus
if not isinstance(init_modulus, np.ndarray): if not isinstance(init_modulus, np.ndarray):
init_modulus = np.array(init_modulus, ndmin=1) init_modulus = np.array(init_modulus, ndmin=1)
ten_init = torch.from_numpy(init_modulus) init_mod = asig(init_modulus)
ten_init = torch.from_numpy(init_mod)
self.bias_modulus.data.copy_(ten_init) self.bias_modulus.data.copy_(ten_init)
if init_theta is None: if init_theta is None:
parts = self.hidden_size * 2 parts = self.hidden_size * 2
ranges = np.arange(1, parts, 2) ranges = np.arange(1, parts, 2)
init_angle = ranges * (max_angle - min_angle) / parts + min_angle init_theta = ranges * (max_angle - min_angle) / parts + min_angle
cosangle = np.cos(init_angle)
init_theta = atanh(cosangle)
if not isinstance(init_theta, np.ndarray): if not isinstance(init_theta, np.ndarray):
init_theta = np.array(init_theta, ndmin=1) init_theta = np.array(init_theta, ndmin=1)
ten_init = torch.from_numpy(init_theta) cosangle = np.cos(init_theta)
init_angle = atanh(cosangle)
ten_init = torch.from_numpy(init_angle)
self.bias_theta.data.copy_(ten_init) self.bias_theta.data.copy_(ten_init)
def __repr__(self): def __repr__(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment