Commit cbb7d345 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

allow changing batch norm params

parent dfa48e09
......@@ -263,7 +263,13 @@ from tensorflow.keras.layers import Flatten
def add_bottleneck(
model, bottleneck_size=128, dropout_rate=0.2, w_decay=5e-4, use_bias=True
model,
bottleneck_size=128,
dropout_rate=0.2,
w_decay=5e-4,
use_bias=True,
batch_norm_decay=0.99,
batch_norm_epsilon=1e-3,
):
"""
Amend a bottleneck layer to a Keras Model
......@@ -286,7 +292,9 @@ def add_bottleneck(
else:
new_model = model
new_model.add(BatchNormalization())
new_model.add(
BatchNormalization(momentum=batch_norm_decay, epsilon=batch_norm_epsilon)
)
new_model.add(Dropout(dropout_rate, name="Dropout"))
new_model.add(Flatten())
......@@ -303,7 +311,14 @@ def add_bottleneck(
)
)
new_model.add(BatchNormalization(axis=-1, name="embeddings"))
new_model.add(
BatchNormalization(
axis=-1,
name="embeddings",
momentum=batch_norm_decay,
epsilon=batch_norm_epsilon,
)
)
# new_model.add(BatchNormalization())
return new_model
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment