diff --git a/bob/learn/tensorflow/models/resnet50_modified.py b/bob/learn/tensorflow/models/resnet50_modified.py index bd2d4bef90ca0e8484337155eec475fbd3a293bb..c6bb4bdfc3407d5eb4b5f1e830129ce8b9be1e84 100644 --- a/bob/learn/tensorflow/models/resnet50_modified.py +++ b/bob/learn/tensorflow/models/resnet50_modified.py @@ -10,10 +10,9 @@ This resnet 50 implementation provides a cleaner version import tensorflow as tf -from tensorflow.keras import layers from tensorflow.keras.regularizers import l2 -from tensorflow.keras.layers import Input, Conv2D, Activation, BatchNormalization -from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D, Flatten, Dense +from tensorflow.keras.layers import Conv2D, Activation, BatchNormalization +from tensorflow.keras.layers import MaxPooling2D global weight_decay weight_decay = 1e-4 @@ -226,7 +225,7 @@ def resnet50_modified(input_tensor=None, input_shape=None, **kwargs): if input_tensor is None: input_tensor = tf.keras.Input(shape=input_shape) else: - if not K.is_keras_tensor(input_tensor): + if not tf.keras.backend.is_keras_tensor(input_tensor): input_tensor = tf.keras.Input(tensor=input_tensor, shape=input_shape) bn_axis = 3 @@ -345,7 +344,7 @@ def resnet101_modified(input_tensor=None, input_shape=None, **kwargs): if __name__ == "__main__": input_tensor = tf.keras.layers.InputLayer([112, 112, 3]) - model = resnet_50(input_tensor) + model = resnet50_modified(input_tensor) print(len(model.variables)) print(model.summary())