Commit b315607e authored by M. François's avatar M. François

consistent init

parent 5202a60f
......@@ -53,13 +53,13 @@ class NeuralFilter(torch.nn.Module):
parts = self.hidden_size * 2
ranges = np.arange(1, parts, 2)
init_modulus = ranges * (max_modulus - min_modulus) / parts + min_modulus
init = asig(init_modulus)
init = ranges * (max_modulus - min_modulus) / parts + min_modulus
if not isinstance(init, np.ndarray):
init = np.array(init, ndmin=1)
ten_init = torch.from_numpy(init)
init_modulus = asig(init)
ten_init = torch.from_numpy(init_modulus)
self.bias_forget.data.copy_(ten_init)
def __repr__(self):
......
......@@ -54,27 +54,28 @@ class NeuralFilter2CC(torch.nn.Module):
min_angle=MIN_ANGLE, max_angle=MAX_ANGLE, modulus=INIT_MODULUS):
if init_modulus is None:
init_modulus = asig(modulus)
init_modulus = modulus
if not isinstance(init_modulus, np.ndarray):
init_modulus = np.array(init_modulus, ndmin=1)
ten_init = torch.from_numpy(init_modulus)
init_mod = asig(init_modulus)
ten_init = torch.from_numpy(init_mod)
self.bias_modulus.data.copy_(ten_init)
if init_theta is None:
parts = self.hidden_size * 2
ranges = np.arange(1, parts, 2)
init_angle = ranges * (max_angle - min_angle) / parts + min_angle
cosangle = np.cos(init_angle)
init_theta = atanh(cosangle)
init_theta = ranges * (max_angle - min_angle) / parts + min_angle
if not isinstance(init_theta, np.ndarray):
init_theta = np.array(init_theta, ndmin=1)
ten_init = torch.from_numpy(init_theta)
cosangle = np.cos(init_theta)
init_angle = atanh(cosangle)
ten_init = torch.from_numpy(init_angle)
self.bias_theta.data.copy_(ten_init)
def __repr__(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment