diff --git a/cnn_training/arcface.py b/cnn_training/arcface.py
index 712cf6dca83019dae9adea274a5622b491a295b9..17b9f508e3283c8f8b4c311c659da27c3b12d460 100644
--- a/cnn_training/arcface.py
+++ b/cnn_training/arcface.py
@@ -52,23 +52,24 @@ validation-tf-record-path: "/path/lfw_pairs.tfrecord"
 
 
 Usage:
-    arcface.py <config-yaml> <checkpoint_path> [--pre-train]
+    arcface.py <config-yaml> <checkpoint_path> [--pre-train --pre-train-epochs=<kn>]
     arcface.py -h | --help
 
 Options:
-  -h --help             Show this screen.
-  --pre-train           If set pretrains the CNN with the crossentropy softmax for 2 epochs  
+  -h --help                        Show this screen.
+  --pre-train                      If set pretrains the CNN with the crossentropy softmax for 2 epochs  
+  --pre-train-epochs=<kn>          Number of epochs to pretrain [default: 40]
   arcface.py arcface -h | help
 
 """
 
 import os
 from functools import partial
-
+import numpy as np
 import pkg_resources
 import tensorflow as tf
 from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2
-from bob.learn.tensorflow.models import resnet50v1
+from bob.learn.tensorflow.models import resnet50_modified, resnet101_modified
 from bob.learn.tensorflow.metrics import predict_using_tensors
 from tensorflow.keras import layers
 from bob.learn.tensorflow.callbacks import add_backup_callback
@@ -100,11 +101,11 @@ BACKBONES = dict()
 BACKBONES["inception-resnet-v2"] = InceptionResNetV2
 BACKBONES["efficientnet-B0"] = tf.keras.applications.EfficientNetB0
 BACKBONES["resnet50"] = tf.keras.applications.ResNet50
+BACKBONES["resnet50_modified"] = resnet50_modified
+BACKBONES["resnet101_modified"] = resnet101_modified
 BACKBONES["mobilenet-v2"] = tf.keras.applications.MobileNetV2
 # from bob.learn.tensorflow.models.lenet5 import LeNet5_simplified
 
-BACKBONES["resnet50v1"] = resnet50v1
-
 ##############################
 # SOLVER SPECIFICATIONS
 ##############################
@@ -134,7 +135,8 @@ DATA_SHAPES = dict()
 
 # Inputs with 182x182 are cropped to 160x160
 DATA_SHAPES[182] = 160
-DATA_SHAPES[112] = 98
+DATA_SHAPES[112] = 112
+# DATA_SHAPES[112] = 98
 DATA_SHAPES[126] = 112
 
 
@@ -157,7 +159,6 @@ VALIDATION_BATCH_SIZE = 38
 def create_model(
     n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape, pre_train
 ):
-
     if backbone == "inception-resnet-v2":
         pre_model = BACKBONES[backbone](
             include_top=False, bottleneck=False, input_shape=input_shape,
@@ -177,12 +178,9 @@ def create_model(
 
     if pre_train:
         # pre_model = add_top(pre_model, n_classes=n_classes)
-        logits_premodel = ArcFaceLayer(
-            n_classes,
-            s=model_spec["arcface"]["s"],
-            m=model_spec["arcface"]["m"],
-            arc=False,
-        )(embeddings, None)
+        logits_premodel = ArcFaceLayer(n_classes, s=0, m=0, arc=False,)(
+            embeddings, None
+        )
 
         # Wrapping the embedding validation
         # logits_premodel = pre_model.get_layer("logits").output
@@ -284,6 +282,7 @@ def train_and_evaluate(
     validation_path,
     lerning_rate_schedule,
     pre_train=False,
+    pre_train_epochs=30,
 ):
 
     # number of training steps to do before validating a model. This also defines an epoch
@@ -343,12 +342,15 @@ def train_and_evaluate(
         # Tracking in the tensorboard
         tf.summary.scalar("learning rate", data=lr, step=epoch)
 
-        if epoch in range(200):
+        if epoch in range(40):
             return 1 * lr
-        elif epoch < 1000:
-            return lr * np.exp(-0.005)
-        else:
+        elif epoch < 300:
+            # return lr * np.exp(-0.005)
+            return 0.01
+        elif epoch < 1200:
             return 0.0001
+        else:
+            return 0.00001
 
     if lerning_rate_schedule == "cosine-decay-restarts":
         decay_steps = 50
@@ -381,7 +383,7 @@ def train_and_evaluate(
         # STEPS_PER_EPOCH
         pre_model.fit(
             train_ds,
-            epochs=20,
+            epochs=int(pre_train_epochs),
             validation_data=val_ds,
             steps_per_epoch=STEPS_PER_EPOCH,
             validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
@@ -449,5 +451,6 @@ if __name__ == "__main__":
         if "lerning-rate-schedule" in config
         else None,
         pre_train=args["--pre-train"],
+        pre_train_epochs=args["--pre-train-epochs"],
     )