Skip to content
Snippets Groups Projects
Commit b315607e authored by M. François's avatar M. François
Browse files

consistent init

parent 5202a60f
No related branches found
No related tags found
No related merge requests found
......@@ -53,13 +53,13 @@ class NeuralFilter(torch.nn.Module):
parts = self.hidden_size * 2
ranges = np.arange(1, parts, 2)
init_modulus = ranges * (max_modulus - min_modulus) / parts + min_modulus
init = asig(init_modulus)
init = ranges * (max_modulus - min_modulus) / parts + min_modulus
if not isinstance(init, np.ndarray):
init = np.array(init, ndmin=1)
ten_init = torch.from_numpy(init)
init_modulus = asig(init)
ten_init = torch.from_numpy(init_modulus)
self.bias_forget.data.copy_(ten_init)
def __repr__(self):
......
......@@ -54,27 +54,28 @@ class NeuralFilter2CC(torch.nn.Module):
min_angle=MIN_ANGLE, max_angle=MAX_ANGLE, modulus=INIT_MODULUS):
if init_modulus is None:
init_modulus = asig(modulus)
init_modulus = modulus
if not isinstance(init_modulus, np.ndarray):
init_modulus = np.array(init_modulus, ndmin=1)
ten_init = torch.from_numpy(init_modulus)
init_mod = asig(init_modulus)
ten_init = torch.from_numpy(init_mod)
self.bias_modulus.data.copy_(ten_init)
if init_theta is None:
parts = self.hidden_size * 2
ranges = np.arange(1, parts, 2)
init_angle = ranges * (max_angle - min_angle) / parts + min_angle
cosangle = np.cos(init_angle)
init_theta = atanh(cosangle)
init_theta = ranges * (max_angle - min_angle) / parts + min_angle
if not isinstance(init_theta, np.ndarray):
init_theta = np.array(init_theta, ndmin=1)
ten_init = torch.from_numpy(init_theta)
cosangle = np.cos(init_theta)
init_angle = atanh(cosangle)
ten_init = torch.from_numpy(init_angle)
self.bias_theta.data.copy_(ten_init)
def __repr__(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment