Skip to content
Snippets Groups Projects
Commit 63993d46 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

nitpick

parent d957c74a
No related branches found
No related tags found
1 merge request!79Add keras-based models, add pixel-wise loss, other improvements
...@@ -32,11 +32,8 @@ class ConvEncoder(tf.keras.Model): ...@@ -32,11 +32,8 @@ class ConvEncoder(tf.keras.Model):
l2_kw = get_l2_kw(weight_decay) l2_kw = get_l2_kw(weight_decay)
layers = [] layers = []
for i, (filters, kernel_size, strides, padding) in enumerate(encoder_layers): for i, (filters, kernel_size, strides, padding) in enumerate(encoder_layers):
pad_kw = {}
# if i == 0:
# pad_kw["input_shape"] = input_shape
pad = tf.keras.layers.ZeroPadding2D( pad = tf.keras.layers.ZeroPadding2D(
padding=padding, data_format=data_format, name=f"pad_{i}", **pad_kw padding=padding, data_format=data_format, name=f"pad_{i}"
) )
conv = tf.keras.layers.Conv2D( conv = tf.keras.layers.Conv2D(
filters, filters,
...@@ -75,17 +72,13 @@ class ConvDecoder(tf.keras.Model): ...@@ -75,17 +72,13 @@ class ConvDecoder(tf.keras.Model):
l2_kw = get_l2_kw(weight_decay) l2_kw = get_l2_kw(weight_decay)
layers = [] layers = []
for i, (filters, kernel_size, strides, cropping) in enumerate(decoder_layers): for i, (filters, kernel_size, strides, cropping) in enumerate(decoder_layers):
dconv_kw = {}
dconv_kw.update(l2_kw)
# if i == 0:
# dconv_kw["input_shape"] = embedding_shape
dconv = tf.keras.layers.Conv2DTranspose( dconv = tf.keras.layers.Conv2DTranspose(
filters, filters,
kernel_size, kernel_size,
strides=strides, strides=strides,
data_format=data_format, data_format=data_format,
name=f"dconv_{i}", name=f"dconv_{i}",
**dconv_kw, **l2_kw,
) )
crop = tf.keras.layers.Cropping2D( crop = tf.keras.layers.Cropping2D(
cropping=cropping, data_format=data_format, name=f"crop_{i}" cropping=cropping, data_format=data_format, name=f"crop_{i}"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment