Made the pretraining network optional

parent 8ddc6018
Pipeline #47999 passed with stage
in 43 minutes and 11 seconds
......@@ -173,12 +173,19 @@ def create_model(
)
embeddings = pre_model.get_layer("embeddings").output
labels = tf.keras.layers.Input([], name="label")
if pre_train:
pre_model = add_top(pre_model, n_classes=n_classes)
# pre_model = add_top(pre_model, n_classes=n_classes)
logits_premodel = ArcFaceLayer(
n_classes,
s=model_spec["arcface"]["s"],
m=model_spec["arcface"]["m"],
arc=False,
)(embeddings, None)
# Wrapping the embedding validation
logits_premodel = pre_model.get_layer("logits").output
# logits_premodel = pre_model.get_layer("logits").output
pre_model = EmbeddingValidation(
pre_model.input, outputs=[logits_premodel, embeddings], name=pre_model.name
......@@ -187,9 +194,11 @@ def create_model(
################################
## Creating the specific models
if "arcface" in model_spec:
labels = tf.keras.layers.Input([], name="label")
logits_arcface = ArcFaceLayer(
n_classes, s=model_spec["arcface"]["s"], m=model_spec["arcface"]["m"]
n_classes,
s=model_spec["arcface"]["s"],
m=model_spec["arcface"]["m"],
arc=True,
)(embeddings, labels)
arc_model = ArcFaceModel(
inputs=(pre_model.input, labels), outputs=[logits_arcface, embeddings]
......@@ -253,7 +262,9 @@ def build_and_compile_models(
optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"],
)
arc_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"])
arc_model.compile(
optimizer=optimizer, loss=cross_entropy, run_eagerly=True, metrics=["accuracy"]
)
return pre_model, arc_model
......@@ -365,17 +376,20 @@ def train_and_evaluate(
callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup")
# Train the Cross-entropy model if the case
if pre_train:
# STEPS_PER_EPOCH
pre_model.fit(
train_ds,
epochs=2,
epochs=20,
validation_data=val_ds,
steps_per_epoch=STEPS_PER_EPOCH,
validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
callbacks=callbacks,
verbose=2,
)
# Copying the last variable
arc_model.trainable_variables[-1].assign(pre_model.trainable_variables[-1])
# STEPS_PER_EPOCH
# epochs=epochs * KERAS_EPOCH_MULTIPLIER,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment