Skip to content
Snippets Groups Projects
Commit 7a498fd9 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Merge branch 'inception-resenet-change' into 'master'

Replace SequentialLayer with Sequential Model

See merge request !95
parents 92c9880e a61605d2
No related branches found
No related tags found
1 merge request!95Replace SequentialLayer with Sequential Model
Pipeline #51535 passed
...@@ -8,7 +8,6 @@ from tensorflow.keras import backend as K ...@@ -8,7 +8,6 @@ from tensorflow.keras import backend as K
from tensorflow.keras.layers import Activation from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import BatchNormalization from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Concatenate from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import GlobalAvgPool2D from tensorflow.keras.layers import GlobalAvgPool2D
...@@ -16,72 +15,11 @@ from tensorflow.keras.layers import GlobalMaxPool2D ...@@ -16,72 +15,11 @@ from tensorflow.keras.layers import GlobalMaxPool2D
from tensorflow.keras.layers import MaxPool2D from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.models import Sequential from tensorflow.keras.models import Sequential
from bob.learn.tensorflow.utils import SequentialLayer from .inception_resnet_v2 import Conv2D_BN
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def Conv2D_BN(
filters,
kernel_size,
strides=1,
padding="same",
activation="relu",
use_bias=False,
name=None,
**kwargs,
):
"""Utility class to apply conv + BN.
# Arguments
x: input tensor.
filters:
kernel_size:
strides:
padding:
activation:
use_bias:
Attributes
----------
activation
activation in `Conv2D`.
filters
filters in `Conv2D`.
kernel_size
kernel size as in `Conv2D`.
padding
padding mode in `Conv2D`.
strides
strides in `Conv2D`.
use_bias
whether to use a bias in `Conv2D`.
name
name of the ops; will become `name + '/Act'` for the activation
and `name + '/BatchNorm'` for the batch norm layer.
"""
layers = [
Conv2D(
filters,
kernel_size,
strides=strides,
padding=padding,
use_bias=use_bias,
name="Conv2D",
)
]
if not use_bias:
bn_axis = 1 if K.image_data_format() == "channels_first" else 3
layers += [BatchNormalization(axis=bn_axis, scale=False, name="BatchNorm")]
if activation is not None:
layers += [Activation(activation, name="Act")]
return SequentialLayer(layers, name=name, **kwargs)
class ScaledResidual(tf.keras.layers.Layer): class ScaledResidual(tf.keras.layers.Layer):
"""A scaled residual connection layer""" """A scaled residual connection layer"""
...@@ -156,24 +94,32 @@ class InceptionResnetBlock(tf.keras.layers.Layer): ...@@ -156,24 +94,32 @@ class InceptionResnetBlock(tf.keras.layers.Layer):
self.n = n self.n = n
if block_type == "block35": if block_type == "block35":
branch_0 = [Conv2D_BN(32 // n, 1, name="Branch_0/Conv2d_1x1")] branch_0 = [Conv2D_BN(32 // n, 1, name=f"{name}/Branch_0/Conv2d_1x1")]
branch_1 = [Conv2D_BN(32 // n, 1, name="Branch_1/Conv2d_0a_1x1")] branch_1 = [Conv2D_BN(32 // n, 1, name=f"{name}/Branch_1/Conv2d_0a_1x1")]
branch_1 += [Conv2D_BN(32 // n, 3, name="Branch_1/Conv2d_0b_3x3")] branch_1 += [Conv2D_BN(32 // n, 3, name=f"{name}/Branch_1/Conv2d_0b_3x3")]
branch_2 = [Conv2D_BN(32 // n, 1, name="Branch_2/Conv2d_0a_1x1")] branch_2 = [Conv2D_BN(32 // n, 1, name=f"{name}/Branch_2/Conv2d_0a_1x1")]
branch_2 += [Conv2D_BN(32 // n, 3, name="Branch_2/Conv2d_0b_3x3")] branch_2 += [Conv2D_BN(32 // n, 3, name=f"{name}/Branch_2/Conv2d_0b_3x3")]
branch_2 += [Conv2D_BN(32 // n, 3, name="Branch_2/Conv2d_0c_3x3")] branch_2 += [Conv2D_BN(32 // n, 3, name=f"{name}/Branch_2/Conv2d_0c_3x3")]
branches = [branch_0, branch_1, branch_2] branches = [branch_0, branch_1, branch_2]
elif block_type == "block17": elif block_type == "block17":
branch_0 = [Conv2D_BN(128 // n, 1, name="Branch_0/Conv2d_1x1")] branch_0 = [Conv2D_BN(128 // n, 1, name=f"{name}/Branch_0/Conv2d_1x1")]
branch_1 = [Conv2D_BN(128 // n, 1, name="Branch_1/Conv2d_0a_1x1")] branch_1 = [Conv2D_BN(128 // n, 1, name=f"{name}/Branch_1/Conv2d_0a_1x1")]
branch_1 += [Conv2D_BN(128 // n, (1, 7), name="Branch_1/Conv2d_0b_1x7")] branch_1 += [
branch_1 += [Conv2D_BN(128 // n, (7, 1), name="Branch_1/Conv2d_0c_7x1")] Conv2D_BN(128 // n, (1, 7), name=f"{name}/Branch_1/Conv2d_0b_1x7")
]
branch_1 += [
Conv2D_BN(128 // n, (7, 1), name=f"{name}/Branch_1/Conv2d_0c_7x1")
]
branches = [branch_0, branch_1] branches = [branch_0, branch_1]
elif block_type == "block8": elif block_type == "block8":
branch_0 = [Conv2D_BN(192 // n, 1, name="Branch_0/Conv2d_1x1")] branch_0 = [Conv2D_BN(192 // n, 1, name=f"{name}/Branch_0/Conv2d_1x1")]
branch_1 = [Conv2D_BN(192 // n, 1, name="Branch_1/Conv2d_0a_1x1")] branch_1 = [Conv2D_BN(192 // n, 1, name=f"{name}/Branch_1/Conv2d_0a_1x1")]
branch_1 += [Conv2D_BN(192 // n, (1, 3), name="Branch_1/Conv2d_0b_1x3")] branch_1 += [
branch_1 += [Conv2D_BN(192 // n, (3, 1), name="Branch_1/Conv2d_0c_3x1")] Conv2D_BN(192 // n, (1, 3), name=f"{name}/Branch_1/Conv2d_0b_1x3")
]
branch_1 += [
Conv2D_BN(192 // n, (3, 1), name=f"{name}/Branch_1/Conv2d_0c_3x1")
]
branches = [branch_0, branch_1] branches = [branch_0, branch_1]
else: else:
raise ValueError( raise ValueError(
...@@ -185,15 +131,15 @@ class InceptionResnetBlock(tf.keras.layers.Layer): ...@@ -185,15 +131,15 @@ class InceptionResnetBlock(tf.keras.layers.Layer):
self.branches = branches self.branches = branches
channel_axis = 1 if K.image_data_format() == "channels_first" else 3 channel_axis = 1 if K.image_data_format() == "channels_first" else 3
self.concat = Concatenate(axis=channel_axis, name="concatenate") self.concat = Concatenate(axis=channel_axis, name=f"{name}/concatenate")
self.up_conv = Conv2D_BN( self.up_conv = Conv2D_BN(
n_channels, 1, activation=None, use_bias=True, name="Conv2d_1x1" n_channels, 1, activation=None, use_bias=True, name=f"{name}/Conv2d_1x1"
) )
self.residual = ScaledResidual(scale) self.residual = ScaledResidual(scale)
self.act = lambda x: x self.act = lambda x: x
if activation is not None: if activation is not None:
self.act = Activation(activation, name="act") self.act = Activation(activation, name=f"{name}/act")
def call(self, inputs, training=None): def call(self, inputs, training=None):
branch_outputs = [] branch_outputs = []
...@@ -258,19 +204,19 @@ class ReductionA(tf.keras.layers.Layer): ...@@ -258,19 +204,19 @@ class ReductionA(tf.keras.layers.Layer):
3, 3,
strides=1 if use_atrous else 2, strides=1 if use_atrous else 2,
padding=padding, padding=padding,
name="Branch_0/Conv2d_1a_3x3", name=f"{name}/Branch_0/Conv2d_1a_3x3",
) )
] ]
branch_1 = [ branch_1 = [
Conv2D_BN(k, 1, name="Branch_1/Conv2d_0a_1x1"), Conv2D_BN(k, 1, name=f"{name}/Branch_1/Conv2d_0a_1x1"),
Conv2D_BN(l, 3, name="Branch_1/Conv2d_0b_3x3"), Conv2D_BN(l, 3, name=f"{name}/Branch_1/Conv2d_0b_3x3"),
Conv2D_BN( Conv2D_BN(
m, m,
3, 3,
strides=1 if use_atrous else 2, strides=1 if use_atrous else 2,
padding=padding, padding=padding,
name="Branch_1/Conv2d_1a_3x3", name=f"{name}/Branch_1/Conv2d_1a_3x3",
), ),
] ]
...@@ -279,7 +225,7 @@ class ReductionA(tf.keras.layers.Layer): ...@@ -279,7 +225,7 @@ class ReductionA(tf.keras.layers.Layer):
3, 3,
strides=1 if use_atrous else 2, strides=1 if use_atrous else 2,
padding=padding, padding=padding,
name="Branch_2/MaxPool_1a_3x3", name=f"{name}/Branch_2/MaxPool_1a_3x3",
) )
] ]
self.branches = [branch_0, branch_1, branch_pool] self.branches = [branch_0, branch_1, branch_pool]
...@@ -337,23 +283,31 @@ class ReductionB(tf.keras.layers.Layer): ...@@ -337,23 +283,31 @@ class ReductionB(tf.keras.layers.Layer):
self.pq = pq self.pq = pq
branch_1 = [ branch_1 = [
Conv2D_BN(n, 1, name="Branch_0/Conv2d_0a_1x1"), Conv2D_BN(n, 1, name=f"{name}/Branch_0/Conv2d_0a_1x1"),
Conv2D_BN(no, 3, strides=2, padding=padding, name="Branch_0/Conv2d_1a_3x3"), Conv2D_BN(
no, 3, strides=2, padding=padding, name=f"{name}/Branch_0/Conv2d_1a_3x3"
),
] ]
branch_2 = [ branch_2 = [
Conv2D_BN(p, 1, name="Branch_1/Conv2d_0a_1x1"), Conv2D_BN(p, 1, name=f"{name}/Branch_1/Conv2d_0a_1x1"),
Conv2D_BN(pq, 3, strides=2, padding=padding, name="Branch_1/Conv2d_1a_3x3"), Conv2D_BN(
pq, 3, strides=2, padding=padding, name=f"{name}/Branch_1/Conv2d_1a_3x3"
),
] ]
branch_3 = [ branch_3 = [
Conv2D_BN(k, 1, name="Branch_2/Conv2d_0a_1x1"), Conv2D_BN(k, 1, name=f"{name}/Branch_2/Conv2d_0a_1x1"),
Conv2D_BN(kl, 3, name="Branch_2/Conv2d_0b_3x3"), Conv2D_BN(kl, 3, name=f"{name}/Branch_2/Conv2d_0b_3x3"),
Conv2D_BN(km, 3, strides=2, padding=padding, name="Branch_2/Conv2d_1a_3x3"), Conv2D_BN(
km, 3, strides=2, padding=padding, name=f"{name}/Branch_2/Conv2d_1a_3x3"
),
] ]
branch_pool = [ branch_pool = [
MaxPool2D(3, strides=2, padding=padding, name="Branch_3/MaxPool_1a_3x3") MaxPool2D(
3, strides=2, padding=padding, name=f"{name}/Branch_3/MaxPool_1a_3x3"
)
] ]
self.branches = [branch_1, branch_2, branch_3, branch_pool] self.branches = [branch_1, branch_2, branch_3, branch_pool]
channel_axis = 1 if K.image_data_format() == "channels_first" else 3 channel_axis = 1 if K.image_data_format() == "channels_first" else 3
......
...@@ -19,8 +19,6 @@ from tensorflow.keras.layers import MaxPool2D ...@@ -19,8 +19,6 @@ from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.models import Model from tensorflow.keras.models import Model
from tensorflow.keras.models import Sequential from tensorflow.keras.models import Sequential
from ..utils import SequentialLayer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -63,6 +61,8 @@ def Conv2D_BN( ...@@ -63,6 +61,8 @@ def Conv2D_BN(
name of the ops; will become `name + '/Act'` for the activation name of the ops; will become `name + '/Act'` for the activation
and `name + '/BatchNorm'` for the batch norm layer. and `name + '/BatchNorm'` for the batch norm layer.
""" """
if name is None:
raise ValueError("name cannot be None!")
layers = [ layers = [
Conv2D( Conv2D(
...@@ -71,18 +71,20 @@ def Conv2D_BN( ...@@ -71,18 +71,20 @@ def Conv2D_BN(
strides=strides, strides=strides,
padding=padding, padding=padding,
use_bias=use_bias, use_bias=use_bias,
name="Conv2D", name=f"{name}/Conv2D",
) )
] ]
if not use_bias: if not use_bias:
bn_axis = 1 if K.image_data_format() == "channels_first" else 3 bn_axis = 1 if K.image_data_format() == "channels_first" else 3
layers += [BatchNormalization(axis=bn_axis, scale=False, name="BatchNorm")] layers += [
BatchNormalization(axis=bn_axis, scale=False, name=f"{name}/BatchNorm")
]
if activation is not None: if activation is not None:
layers += [Activation(activation, name="Act")] layers += [Activation(activation, name=f"{name}/Act")]
return SequentialLayer(layers, name=name, **kwargs) return tf.keras.Sequential(layers, name=name, **kwargs)
class ScaledResidual(tf.keras.layers.Layer): class ScaledResidual(tf.keras.layers.Layer):
...@@ -159,24 +161,32 @@ class InceptionResnetBlock(tf.keras.layers.Layer): ...@@ -159,24 +161,32 @@ class InceptionResnetBlock(tf.keras.layers.Layer):
self.n = n self.n = n
if block_type == "block35": if block_type == "block35":
branch_0 = [Conv2D_BN(32 // n, 1, name="Branch_0/Conv2d_1x1")] branch_0 = [Conv2D_BN(32 // n, 1, name=f"{name}/Branch_0/Conv2d_1x1")]
branch_1 = [Conv2D_BN(32 // n, 1, name="Branch_1/Conv2d_0a_1x1")] branch_1 = [Conv2D_BN(32 // n, 1, name=f"{name}/Branch_1/Conv2d_0a_1x1")]
branch_1 += [Conv2D_BN(32 // n, 3, name="Branch_1/Conv2d_0b_3x3")] branch_1 += [Conv2D_BN(32 // n, 3, name=f"{name}/Branch_1/Conv2d_0b_3x3")]
branch_2 = [Conv2D_BN(32 // n, 1, name="Branch_2/Conv2d_0a_1x1")] branch_2 = [Conv2D_BN(32 // n, 1, name=f"{name}/Branch_2/Conv2d_0a_1x1")]
branch_2 += [Conv2D_BN(48 // n, 3, name="Branch_2/Conv2d_0b_3x3")] branch_2 += [Conv2D_BN(48 // n, 3, name=f"{name}/Branch_2/Conv2d_0b_3x3")]
branch_2 += [Conv2D_BN(64 // n, 3, name="Branch_2/Conv2d_0c_3x3")] branch_2 += [Conv2D_BN(64 // n, 3, name=f"{name}/Branch_2/Conv2d_0c_3x3")]
branches = [branch_0, branch_1, branch_2] branches = [branch_0, branch_1, branch_2]
elif block_type == "block17": elif block_type == "block17":
branch_0 = [Conv2D_BN(192 // n, 1, name="Branch_0/Conv2d_1x1")] branch_0 = [Conv2D_BN(192 // n, 1, name=f"{name}/Branch_0/Conv2d_1x1")]
branch_1 = [Conv2D_BN(128 // n, 1, name="Branch_1/Conv2d_0a_1x1")] branch_1 = [Conv2D_BN(128 // n, 1, name=f"{name}/Branch_1/Conv2d_0a_1x1")]
branch_1 += [Conv2D_BN(160 // n, (1, 7), name="Branch_1/Conv2d_0b_1x7")] branch_1 += [
branch_1 += [Conv2D_BN(192 // n, (7, 1), name="Branch_1/Conv2d_0c_7x1")] Conv2D_BN(160 // n, (1, 7), name=f"{name}/Branch_1/Conv2d_0b_1x7")
]
branch_1 += [
Conv2D_BN(192 // n, (7, 1), name=f"{name}/Branch_1/Conv2d_0c_7x1")
]
branches = [branch_0, branch_1] branches = [branch_0, branch_1]
elif block_type == "block8": elif block_type == "block8":
branch_0 = [Conv2D_BN(192 // n, 1, name="Branch_0/Conv2d_1x1")] branch_0 = [Conv2D_BN(192 // n, 1, name=f"{name}/Branch_0/Conv2d_1x1")]
branch_1 = [Conv2D_BN(192 // n, 1, name="Branch_1/Conv2d_0a_1x1")] branch_1 = [Conv2D_BN(192 // n, 1, name=f"{name}/Branch_1/Conv2d_0a_1x1")]
branch_1 += [Conv2D_BN(224 // n, (1, 3), name="Branch_1/Conv2d_0b_1x3")] branch_1 += [
branch_1 += [Conv2D_BN(256 // n, (3, 1), name="Branch_1/Conv2d_0c_3x1")] Conv2D_BN(224 // n, (1, 3), name=f"{name}/Branch_1/Conv2d_0b_1x3")
]
branch_1 += [
Conv2D_BN(256 // n, (3, 1), name=f"{name}/Branch_1/Conv2d_0c_3x1")
]
branches = [branch_0, branch_1] branches = [branch_0, branch_1]
else: else:
raise ValueError( raise ValueError(
...@@ -188,15 +198,15 @@ class InceptionResnetBlock(tf.keras.layers.Layer): ...@@ -188,15 +198,15 @@ class InceptionResnetBlock(tf.keras.layers.Layer):
self.branches = branches self.branches = branches
channel_axis = 1 if K.image_data_format() == "channels_first" else 3 channel_axis = 1 if K.image_data_format() == "channels_first" else 3
self.concat = Concatenate(axis=channel_axis, name="concatenate") self.concat = Concatenate(axis=channel_axis, name=f"{name}/concatenate")
self.up_conv = Conv2D_BN( self.up_conv = Conv2D_BN(
n_channels, 1, activation=None, use_bias=True, name="Conv2d_1x1" n_channels, 1, activation=None, use_bias=True, name=f"{name}/Conv2d_1x1"
) )
self.residual = ScaledResidual(scale) self.residual = ScaledResidual(scale)
self.act = lambda x: x self.act = lambda x: x
if activation is not None: if activation is not None:
self.act = Activation(activation, name="act") self.act = Activation(activation, name=f"{name}/act")
def call(self, inputs, training=None): def call(self, inputs, training=None):
branch_outputs = [] branch_outputs = []
...@@ -261,19 +271,19 @@ class ReductionA(tf.keras.layers.Layer): ...@@ -261,19 +271,19 @@ class ReductionA(tf.keras.layers.Layer):
3, 3,
strides=1 if use_atrous else 2, strides=1 if use_atrous else 2,
padding=padding, padding=padding,
name="Branch_0/Conv2d_1a_3x3", name=f"{name}/Branch_0/Conv2d_1a_3x3",
) )
] ]
branch_2 = [ branch_2 = [
Conv2D_BN(k, 1, name="Branch_1/Conv2d_0a_1x1"), Conv2D_BN(k, 1, name=f"{name}/Branch_1/Conv2d_0a_1x1"),
Conv2D_BN(kl, 3, name="Branch_1/Conv2d_0b_3x3"), Conv2D_BN(kl, 3, name=f"{name}/Branch_1/Conv2d_0b_3x3"),
Conv2D_BN( Conv2D_BN(
km, km,
3, 3,
strides=1 if use_atrous else 2, strides=1 if use_atrous else 2,
padding=padding, padding=padding,
name="Branch_1/Conv2d_1a_3x3", name=f"{name}/Branch_1/Conv2d_1a_3x3",
), ),
] ]
...@@ -282,7 +292,7 @@ class ReductionA(tf.keras.layers.Layer): ...@@ -282,7 +292,7 @@ class ReductionA(tf.keras.layers.Layer):
3, 3,
strides=1 if use_atrous else 2, strides=1 if use_atrous else 2,
padding=padding, padding=padding,
name="Branch_2/MaxPool_1a_3x3", name=f"{name}/Branch_2/MaxPool_1a_3x3",
) )
] ]
self.branches = [branch_1, branch_2, branch_pool] self.branches = [branch_1, branch_2, branch_pool]
...@@ -340,23 +350,31 @@ class ReductionB(tf.keras.layers.Layer): ...@@ -340,23 +350,31 @@ class ReductionB(tf.keras.layers.Layer):
self.pq = pq self.pq = pq
branch_1 = [ branch_1 = [
Conv2D_BN(n, 1, name="Branch_0/Conv2d_0a_1x1"), Conv2D_BN(n, 1, name=f"{name}/Branch_0/Conv2d_0a_1x1"),
Conv2D_BN(no, 3, strides=2, padding=padding, name="Branch_0/Conv2d_1a_3x3"), Conv2D_BN(
no, 3, strides=2, padding=padding, name=f"{name}/Branch_0/Conv2d_1a_3x3"
),
] ]
branch_2 = [ branch_2 = [
Conv2D_BN(p, 1, name="Branch_1/Conv2d_0a_1x1"), Conv2D_BN(p, 1, name=f"{name}/Branch_1/Conv2d_0a_1x1"),
Conv2D_BN(pq, 3, strides=2, padding=padding, name="Branch_1/Conv2d_1a_3x3"), Conv2D_BN(
pq, 3, strides=2, padding=padding, name=f"{name}/Branch_1/Conv2d_1a_3x3"
),
] ]
branch_3 = [ branch_3 = [
Conv2D_BN(k, 1, name="Branch_2/Conv2d_0a_1x1"), Conv2D_BN(k, 1, name=f"{name}/Branch_2/Conv2d_0a_1x1"),
Conv2D_BN(kl, 3, name="Branch_2/Conv2d_0b_3x3"), Conv2D_BN(kl, 3, name=f"{name}/Branch_2/Conv2d_0b_3x3"),
Conv2D_BN(km, 3, strides=2, padding=padding, name="Branch_2/Conv2d_1a_3x3"), Conv2D_BN(
km, 3, strides=2, padding=padding, name=f"{name}/Branch_2/Conv2d_1a_3x3"
),
] ]
branch_pool = [ branch_pool = [
MaxPool2D(3, strides=2, padding=padding, name="Branch_3/MaxPool_1a_3x3") MaxPool2D(
3, strides=2, padding=padding, name=f"{name}/Branch_3/MaxPool_1a_3x3"
)
] ]
self.branches = [branch_1, branch_2, branch_3, branch_pool] self.branches = [branch_1, branch_2, branch_3, branch_pool]
channel_axis = 1 if K.image_data_format() == "channels_first" else 3 channel_axis = 1 if K.image_data_format() == "channels_first" else 3
...@@ -392,31 +410,37 @@ class InceptionA(tf.keras.layers.Layer): ...@@ -392,31 +410,37 @@ class InceptionA(tf.keras.layers.Layer):
self.pool_filters = pool_filters self.pool_filters = pool_filters
self.branch1x1 = Conv2D_BN( self.branch1x1 = Conv2D_BN(
96, kernel_size=1, padding="same", name="Branch_0/Conv2d_1x1" 96, kernel_size=1, padding="same", name=f"{name}/Branch_0/Conv2d_1x1"
) )
self.branch3x3dbl_1 = Conv2D_BN( self.branch3x3dbl_1 = Conv2D_BN(
64, kernel_size=1, padding="same", name="Branch_2/Conv2d_0a_1x1" 64, kernel_size=1, padding="same", name=f"{name}/Branch_2/Conv2d_0a_1x1"
) )
self.branch3x3dbl_2 = Conv2D_BN( self.branch3x3dbl_2 = Conv2D_BN(
96, kernel_size=3, padding="same", name="Branch_2/Conv2d_0b_3x3" 96, kernel_size=3, padding="same", name=f"{name}/Branch_2/Conv2d_0b_3x3"
) )
self.branch3x3dbl_3 = Conv2D_BN( self.branch3x3dbl_3 = Conv2D_BN(
96, kernel_size=3, padding="same", name="Branch_2/Conv2d_0c_3x3" 96, kernel_size=3, padding="same", name=f"{name}/Branch_2/Conv2d_0c_3x3"
) )
self.branch5x5_1 = Conv2D_BN( self.branch5x5_1 = Conv2D_BN(
48, kernel_size=1, padding="same", name="Branch_1/Conv2d_0a_1x1" 48, kernel_size=1, padding="same", name=f"{name}/Branch_1/Conv2d_0a_1x1"
) )
self.branch5x5_2 = Conv2D_BN( self.branch5x5_2 = Conv2D_BN(
64, kernel_size=5, padding="same", name="Branch_1/Conv2d_0b_5x5" 64, kernel_size=5, padding="same", name=f"{name}/Branch_1/Conv2d_0b_5x5"
) )
self.branch_pool_1 = AvgPool2D( self.branch_pool_1 = AvgPool2D(
pool_size=3, strides=1, padding="same", name="Branch_3/AvgPool_0a_3x3" pool_size=3,
strides=1,
padding="same",
name=f"{name}/Branch_3/AvgPool_0a_3x3",
) )
self.branch_pool_2 = Conv2D_BN( self.branch_pool_2 = Conv2D_BN(
pool_filters, kernel_size=1, padding="same", name="Branch_3/Conv2d_0b_1x1" pool_filters,
kernel_size=1,
padding="same",
name=f"{name}/Branch_3/Conv2d_0b_1x1",
) )
channel_axis = 1 if K.image_data_format() == "channels_first" else 3 channel_axis = 1 if K.image_data_format() == "channels_first" else 3
......
import copy
import logging import logging
import tensorflow as tf import tensorflow as tf
import tensorflow.keras.backend as K import tensorflow.keras.backend as K
from tensorflow.python.keras import layers as layer_module
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -17,67 +13,6 @@ SINGLE_LAYER_OUTPUT_ERROR_MSG = ( ...@@ -17,67 +13,6 @@ SINGLE_LAYER_OUTPUT_ERROR_MSG = (
) )
class SequentialLayer(tf.keras.layers.Layer):
"""A Layer that does the same thing as tf.keras.Sequential but
its variables can be scoped.
Parameters
----------
layers : list
List of layers. All layers must be provided at initialization time
"""
def __init__(self, layers, **kwargs):
super().__init__(**kwargs)
self.sequential_layers = list(layers)
def call(self, inputs, training=None, mask=None):
outputs = inputs
for layer in self.sequential_layers:
# During each iteration, `inputs` are the inputs to `layer`, and `outputs`
# are the outputs of `layer` applied to `inputs`. At the end of each
# iteration `inputs` is set to `outputs` to prepare for the next layer.
kwargs = {}
argspec = tf_inspect.getfullargspec(layer.call).args
if "mask" in argspec:
kwargs["mask"] = mask
if "training" in argspec:
kwargs["training"] = training
outputs = layer(outputs, **kwargs)
if len(nest.flatten(outputs)) != 1:
raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
mask = getattr(outputs, "_keras_mask", None)
return outputs
def get_config(self):
layer_configs = []
for layer in self.sequential_layers:
layer_configs.append(generic_utils.serialize_keras_object(layer))
config = {"name": self.name, "layers": copy.deepcopy(layer_configs)}
return config
@classmethod
def from_config(cls, config, custom_objects=None):
if "name" in config:
name = config["name"]
layer_configs = config["layers"]
else:
name = None
layer_configs = config
layers = []
for layer_config in layer_configs:
layer = layer_module.deserialize(
layer_config, custom_objects=custom_objects
)
layers.append(layer)
model = cls(layers, name=name)
return model
def keras_channels_index(): def keras_channels_index():
return -3 if K.image_data_format() == "channels_first" else -1 return -3 if K.image_data_format() == "channels_first" else -1
......
...@@ -52,7 +52,6 @@ Keras Utilities ...@@ -52,7 +52,6 @@ Keras Utilities
=============== ===============
.. autosummary:: .. autosummary::
bob.learn.tensorflow.utils.keras.SequentialLayer
bob.learn.tensorflow.utils.keras.keras_channels_index bob.learn.tensorflow.utils.keras.keras_channels_index
bob.learn.tensorflow.utils.keras.keras_model_weights_as_initializers_for_variables bob.learn.tensorflow.utils.keras.keras_model_weights_as_initializers_for_variables
bob.learn.tensorflow.utils.keras.restore_model_variables_from_checkpoint bob.learn.tensorflow.utils.keras.restore_model_variables_from_checkpoint
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment