Commit e7ca90b0 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

use channels last by default

parent 6b189326
import tensorflow as tf
from ..utils import get_available_gpus, to_channels_first
def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN,
kernerl_size=(3, 3), n_classes=2):
data_format = 'channels_last'
if len(get_available_gpus()) != 0:
# When running on GPU, transpose the data from channels_last (NHWC) to
# channels_first (NCHW) to improve performance. See
# https://www.tensorflow.org/performance/performance_guide#data_formats
input_layer = to_channels_first('input_layer')
data_format = 'channels_first'
kernerl_size=(3, 3), n_classes=2,
data_format='channels_last'):
# Keep track of all the endpoints
endpoints = {}
......@@ -78,15 +71,20 @@ def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN,
def model_fn(features, labels, mode, params=None, config=None):
"""Model function for CNN."""
data = features['data']
keys = features['key']
params = params or {}
learning_rate = params.get('learning_rate', 1e-5)
kernerl_size = params.get('kernerl_size', (3, 3))
n_classes = params.get('n_classes', 2)
data = features['data']
keys = features['keys']
logits, _ = architecture(
data, mode, kernerl_size=kernerl_size, n_classes=n_classes)
arch_kwargs = {
'kernerl_size': params.get('kernerl_size', None),
'n_classes': params.get('n_classes', None),
'data_format': params.get('data_format', None),
}
arch_kwargs = {k: v for k, v in arch_kwargs.items() if v is not None}
logits, _ = architecture(data, mode, **arch_kwargs)
predictions = {
# Generate predictions (for PREDICT and EVAL mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment