Update some problems in CenterLoss

parent 142371cc
Pipeline #44635 failed with stage
in 26 minutes and 30 seconds
......@@ -60,7 +60,6 @@ class CenterLossModel(tf.keras.Model):
train_cross_entropy,
train_center_loss,
test_acc,
global_batch_size,
**kwargs,
):
super().compile(**kwargs)
......@@ -71,7 +70,6 @@ class CenterLossModel(tf.keras.Model):
self.train_cross_entropy = train_cross_entropy
self.train_center_loss = train_center_loss
self.test_acc = test_acc
self.global_batch_size = global_batch_size
def train_step(self, data):
images, labels = data
......@@ -123,17 +121,16 @@ def create_model(n_classes):
return model
def build_and_compile_model(n_classes, learning_rate, global_batch_size):
def build_and_compile_model(n_classes, learning_rate):
model = create_model(n_classes)
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, name="cross_entropy", reduction=tf.keras.losses.Reduction.NONE
from_logits=True, name="cross_entropy"
)
center_loss = CenterLoss(
centers_layer=model.get_layer("centers"),
alpha=0.9,
name="center_loss",
reduction=tf.keras.losses.Reduction.NONE,
)
optimizer = tf.keras.optimizers.RMSprop(
......@@ -155,7 +152,6 @@ def build_and_compile_model(n_classes, learning_rate, global_batch_size):
train_cross_entropy=train_cross_entropy,
train_center_loss=train_center_loss,
test_acc=test_acc,
global_batch_size=global_batch_size,
)
return model
......@@ -209,7 +205,7 @@ def train_and_evaluate(tf_record_paths, checkpoint_path, n_classes, batch_size,
)
val_metric_name = "val_accuracy"
model = build_and_compile_model(n_classes, learning_rate, global_batch_size=batch_size)
model = build_and_compile_model(n_classes, learning_rate)
def scheduler(epoch, lr):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment