diff --git a/cnn_training/centerloss.py b/cnn_training/centerloss.py
index 50511ad764624948cc67987053a9f8f9ab9a5372..68bf2c300372c913ce3c47da25f52009fceb7dc1 100644
--- a/cnn_training/centerloss.py
+++ b/cnn_training/centerloss.py
@@ -60,7 +60,6 @@ class CenterLossModel(tf.keras.Model):
         train_cross_entropy,
         train_center_loss,
         test_acc,
-        global_batch_size,
         **kwargs,
     ):
         super().compile(**kwargs)
@@ -71,7 +70,6 @@ class CenterLossModel(tf.keras.Model):
         self.train_cross_entropy = train_cross_entropy
         self.train_center_loss = train_center_loss
         self.test_acc = test_acc
-        self.global_batch_size = global_batch_size
 
     def train_step(self, data):
         images, labels = data
@@ -123,17 +121,16 @@ def create_model(n_classes):
     return model
 
 
-def build_and_compile_model(n_classes, learning_rate, global_batch_size):
+def build_and_compile_model(n_classes, learning_rate):
     model = create_model(n_classes)
 
     cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
-        from_logits=True, name="cross_entropy", reduction=tf.keras.losses.Reduction.NONE
+        from_logits=True, name="cross_entropy"
     )
     center_loss = CenterLoss(
         centers_layer=model.get_layer("centers"),
         alpha=0.9,
         name="center_loss",
-        reduction=tf.keras.losses.Reduction.NONE,
     )
 
     optimizer = tf.keras.optimizers.RMSprop(
@@ -155,7 +152,6 @@ def build_and_compile_model(n_classes, learning_rate, global_batch_size):
         train_cross_entropy=train_cross_entropy,
         train_center_loss=train_center_loss,
         test_acc=test_acc,
-        global_batch_size=global_batch_size,
     )
     return model
 
@@ -209,7 +205,7 @@ def train_and_evaluate(tf_record_paths, checkpoint_path, n_classes, batch_size,
     )
     val_metric_name = "val_accuracy"
 
-    model = build_and_compile_model(n_classes, learning_rate, global_batch_size=batch_size)
+    model = build_and_compile_model(n_classes, learning_rate)
 
 
     def scheduler(epoch, lr):