diff --git a/cnn_training/centerloss.py b/cnn_training/centerloss.py index 50511ad764624948cc67987053a9f8f9ab9a5372..68bf2c300372c913ce3c47da25f52009fceb7dc1 100644 --- a/cnn_training/centerloss.py +++ b/cnn_training/centerloss.py @@ -60,7 +60,6 @@ class CenterLossModel(tf.keras.Model): train_cross_entropy, train_center_loss, test_acc, - global_batch_size, **kwargs, ): super().compile(**kwargs) @@ -71,7 +70,6 @@ class CenterLossModel(tf.keras.Model): self.train_cross_entropy = train_cross_entropy self.train_center_loss = train_center_loss self.test_acc = test_acc - self.global_batch_size = global_batch_size def train_step(self, data): images, labels = data @@ -123,17 +121,16 @@ def create_model(n_classes): return model -def build_and_compile_model(n_classes, learning_rate, global_batch_size): +def build_and_compile_model(n_classes, learning_rate): model = create_model(n_classes) cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy( - from_logits=True, name="cross_entropy", reduction=tf.keras.losses.Reduction.NONE + from_logits=True, name="cross_entropy" ) center_loss = CenterLoss( centers_layer=model.get_layer("centers"), alpha=0.9, name="center_loss", - reduction=tf.keras.losses.Reduction.NONE, ) optimizer = tf.keras.optimizers.RMSprop( @@ -155,7 +152,6 @@ def build_and_compile_model(n_classes, learning_rate, global_batch_size): train_cross_entropy=train_cross_entropy, train_center_loss=train_center_loss, test_acc=test_acc, - global_batch_size=global_batch_size, ) return model @@ -209,7 +205,7 @@ def train_and_evaluate(tf_record_paths, checkpoint_path, n_classes, batch_size, ) val_metric_name = "val_accuracy" - model = build_and_compile_model(n_classes, learning_rate, global_batch_size=batch_size) + model = build_and_compile_model(n_classes, learning_rate) def scheduler(epoch, lr):