Skip to content
Snippets Groups Projects
Commit 750171bb authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed some bugs

parent feff4f40
No related branches found
No related tags found
1 merge request!94Properly implemented resnet50 and resnet101
Pipeline #50090 failed
...@@ -10,10 +10,9 @@ This resnet 50 implementation provides a cleaner version ...@@ -10,10 +10,9 @@ This resnet 50 implementation provides a cleaner version
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.regularizers import l2 from tensorflow.keras.regularizers import l2
from tensorflow.keras.layers import Input, Conv2D, Activation, BatchNormalization from tensorflow.keras.layers import Conv2D, Activation, BatchNormalization
from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D, Flatten, Dense from tensorflow.keras.layers import MaxPooling2D
global weight_decay global weight_decay
weight_decay = 1e-4 weight_decay = 1e-4
...@@ -226,7 +225,7 @@ def resnet50_modified(input_tensor=None, input_shape=None, **kwargs): ...@@ -226,7 +225,7 @@ def resnet50_modified(input_tensor=None, input_shape=None, **kwargs):
if input_tensor is None: if input_tensor is None:
input_tensor = tf.keras.Input(shape=input_shape) input_tensor = tf.keras.Input(shape=input_shape)
else: else:
if not K.is_keras_tensor(input_tensor): if not tf.keras.backend.is_keras_tensor(input_tensor):
input_tensor = tf.keras.Input(tensor=input_tensor, shape=input_shape) input_tensor = tf.keras.Input(tensor=input_tensor, shape=input_shape)
bn_axis = 3 bn_axis = 3
...@@ -345,7 +344,7 @@ def resnet101_modified(input_tensor=None, input_shape=None, **kwargs): ...@@ -345,7 +344,7 @@ def resnet101_modified(input_tensor=None, input_shape=None, **kwargs):
if __name__ == "__main__": if __name__ == "__main__":
input_tensor = tf.keras.layers.InputLayer([112, 112, 3]) input_tensor = tf.keras.layers.InputLayer([112, 112, 3])
model = resnet_50(input_tensor) model = resnet50_modified(input_tensor)
print(len(model.variables)) print(len(model.variables))
print(model.summary()) print(model.summary())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment