Commit e57ca8e9 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

[py] Small fixes in v1

parent 4eb8a009
Pipeline #45053 failed with stage
in 3 minutes and 25 seconds
......@@ -163,8 +163,8 @@ class InceptionResnetBlock(tf.keras.layers.Layer):
branch_1 = [Conv2D_BN(32 // n, 1, name="Branch_1/Conv2d_0a_1x1")]
branch_1 += [Conv2D_BN(32 // n, 3, name="Branch_1/Conv2d_0b_3x3")]
branch_2 = [Conv2D_BN(32 // n, 1, name="Branch_2/Conv2d_0a_1x1")]
branch_2 += [Conv2D_BN(48 // n, 3, name="Branch_2/Conv2d_0b_3x3")]
branch_2 += [Conv2D_BN(64 // n, 3, name="Branch_2/Conv2d_0c_3x3")]
branch_2 += [Conv2D_BN(32 // n, 3, name="Branch_2/Conv2d_0b_3x3")]
branch_2 += [Conv2D_BN(32 // n, 3, name="Branch_2/Conv2d_0c_3x3")]
branches = [branch_0, branch_1, branch_2]
elif block_type == "block17":
branch_0 = [Conv2D_BN(128 // n, 1, name="Branch_0/Conv2d_1x1")]
......@@ -175,8 +175,8 @@ class InceptionResnetBlock(tf.keras.layers.Layer):
elif block_type == "block8":
branch_0 = [Conv2D_BN(192 // n, 1, name="Branch_0/Conv2d_1x1")]
branch_1 = [Conv2D_BN(192 // n, 1, name="Branch_1/Conv2d_0a_1x1")]
branch_1 += [Conv2D_BN(224 // n, (1, 3), name="Branch_1/Conv2d_0b_1x3")]
branch_1 += [Conv2D_BN(256 // n, (3, 1), name="Branch_1/Conv2d_0c_3x1")]
branch_1 += [Conv2D_BN(192 // n, (1, 3), name="Branch_1/Conv2d_0b_1x3")]
branch_1 += [Conv2D_BN(192 // n, (3, 1), name="Branch_1/Conv2d_0c_3x1")]
branches = [branch_0, branch_1]
else:
raise ValueError(
......@@ -307,14 +307,14 @@ class ReductionA(tf.keras.layers.Layer):
config.update(
{
name: getattr(self, name)
for name in ["padding", "k", "kl", "km", "n", "use_atrous", "name"]
for name in ["padding", "k", "l", "m", "n", "use_atrous", "name"]
}
)
return config
class ReductionB(tf.keras.layers.Layer):
"""A Reduction B block for InceptionResnetV2"""
"""A Reduction B block for InceptionResnetV1"""
def __init__(
self,
......@@ -386,64 +386,6 @@ class ReductionB(tf.keras.layers.Layer):
return config
class InceptionA(tf.keras.layers.Layer):
def __init__(self, pool_filters, name="inception_a", **kwargs):
super().__init__(name=name, **kwargs)
self.pool_filters = pool_filters
self.branch1x1 = Conv2D_BN(
96, kernel_size=1, padding="same", name="Branch_0/Conv2d_1x1"
)
self.branch3x3dbl_1 = Conv2D_BN(
64, kernel_size=1, padding="same", name="Branch_2/Conv2d_0a_1x1"
)
self.branch3x3dbl_2 = Conv2D_BN(
96, kernel_size=3, padding="same", name="Branch_2/Conv2d_0b_3x3"
)
self.branch3x3dbl_3 = Conv2D_BN(
96, kernel_size=3, padding="same", name="Branch_2/Conv2d_0c_3x3"
)
self.branch5x5_1 = Conv2D_BN(
48, kernel_size=1, padding="same", name="Branch_1/Conv2d_0a_1x1"
)
self.branch5x5_2 = Conv2D_BN(
64, kernel_size=5, padding="same", name="Branch_1/Conv2d_0b_5x5"
)
self.branch_pool_1 = AvgPool2D(
pool_size=3, strides=1, padding="same", name="Branch_3/AvgPool_0a_3x3"
)
self.branch_pool_2 = Conv2D_BN(
pool_filters, kernel_size=1, padding="same", name="Branch_3/Conv2d_0b_1x1"
)
channel_axis = 1 if K.image_data_format() == "channels_first" else 3
self.concat = Concatenate(axis=channel_axis)
def call(self, inputs, training=None):
branch1x1 = self.branch1x1(inputs)
branch3x3dbl = self.branch3x3dbl_1(inputs)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
branch5x5 = self.branch5x5_1(inputs)
branch5x5 = self.branch5x5_2(branch5x5)
branch_pool = self.branch_pool_1(inputs)
branch_pool = self.branch_pool_2(branch_pool)
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
return self.concat(outputs)
def get_config(self):
config = super().get_config()
config.update({"pool_filters": self.pool_filters, "name": self.name})
return config
def InceptionResNetV1(
include_top=True,
input_tensor=None,
......@@ -452,7 +394,7 @@ def InceptionResNetV1(
classes=1000,
bottleneck=False,
dropout_rate=0.2,
name="InceptionResnetV2",
name="InceptionResnetV1",
**kwargs,
):
"""Instantiates the Inception-ResNet v1 architecture.
......@@ -580,8 +522,8 @@ def InceptionResNetV1(
scale=1.0,
activation=None,
block_type="block8",
block_idx=10,
name=f"Mixed_8b",
block_idx=5,
name=f"block8_5",
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment