Commit 8ed34b1f authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Add docs and examples

parent ee013d0d
@@ -42,8 +42,8 @@ def datasets_to_tfrecords(dataset, output, force, **kwargs):
You can convert the written TFRecord files back to datasets using
:any:`bob.learn.tensorflow.dataset.tfrecords.dataset_from_tfrecord`.
To use this script with SGE, shard your dataset and output a part of the
dataset based on the SGE_TASK_ID environment variable in your config file.
"""
from bob.extension.scripts.click_helper import log_parameters
import os
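# Illustration only (not part of this file): a minimal config sketch for the SGE
# usage described in the docstring above. ``n_tasks`` and the output pattern are
# assumptions; the real config must define ``dataset`` and ``output`` for the script.
import os

import tensorflow as tf

task_id = int(os.environ.get("SGE_TASK_ID", "1")) - 1  # SGE task ids start at 1
n_tasks = 64  # e.g. an array job submitted with 64 tasks

# build the full dataset here (placeholder shown), then keep only this task's shard
dataset = tf.data.Dataset.range(1000)
dataset = dataset.shard(num_shards=n_tasks, index=task_id)
output = f"tfrecords/part-{task_id:03d}.tfrecord"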
@@ -3,11 +3,12 @@
.. _bob.learn.tensorflow:
=======================
Tensorflow Biometrics
=======================
This package extends the high-level API of Tensorflow to allow biometrics
experiments. In particular, it provides the tools to train deep models for
biometric recognition and presentation attack detection.
User's Guide
============
.. vim: set fileencoding=utf-8 :
===========
User guide
===========
This package builds on top of tensorflow_ (at least version 2.3 is needed). You
are expected to have some familiarity with it before continuing. The best way to
use tensorflow_ is through its ``tf.keras`` and ``tf.data`` APIs. We recommend
reading at least the following pages:

* https://www.tensorflow.org/tutorials/quickstart/beginner
* https://www.tensorflow.org/tutorials/quickstart/advanced
* https://keras.io/getting_started/intro_to_keras_for_engineers/
* https://keras.io/getting_started/intro_to_keras_for_researchers/
* https://www.tensorflow.org/tutorials/load_data/images
* https://www.tensorflow.org/guide/data
If you are coming from the Tensorflow 1 API, reading these pages is also
recommended:

* https://www.tensorflow.org/guide/effective_tf2
* https://www.tensorflow.org/guide/migrate
* https://www.tensorflow.org/guide/upgrade
* https://github.com/tensorflow/community/blob/master/sigs/testing/faq.md
In the rest of this guide, you will learn a few tips and examples on how to:

* Port v1 checkpoints to the tf v2 format.
* Create datasets and save TFRecords.
* Create models with custom training and evaluation logic.
* Do mixed-precision training.
* Do multi-GPU and multi-worker training.
After reading this page, you may look at a complete example in:
https://gitlab.idiap.ch/bob/bob.learn.tensorflow/-/blob/master/examples/MSCeleba_centerloss_mixed_precision_multi_worker.py
Porting V1 Tensorflow checkpoints to V2
=======================================
Take a look at the notebook located at:
https://gitlab.idiap.ch/bob/bob.learn.tensorflow/-/blob/master/examples/convert_v1_checkpoints_to_v2.ipynb
for an example.
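The notebook covers the details; the general idea is sketched below. This is a
minimal, hypothetical example: the model, the checkpoint paths, and the
variable-name mapping are placeholders, and the real mapping depends on how the
v1 graph was built::

    import tensorflow as tf

    # Read the v1 checkpoint (accepts a directory or a file prefix).
    reader = tf.train.load_checkpoint("/path/to/v1_checkpoint")

    # A v2 Keras model whose variables we want to fill.
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(10, input_shape=(4,), name="dense")]
    )

    # Hypothetical mapping: v2 variable name -> v1 tensor name in the checkpoint.
    name_map = {
        "dense/kernel:0": "fc1/weights",
        "dense/bias:0": "fc1/biases",
    }
    for var in model.variables:
        var.assign(reader.get_tensor(name_map[var.name]))

    # Save the weights in the v2 format.
    model.save_weights("/path/to/v2_checkpoint")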
Creating datasets from data
===========================

If you are working with Bob databases, below is an example of converting them to
``tf.data.Dataset``'s using :any:`bob.learn.tensorflow.data.dataset_using_generator`:

.. testsetup::

    import tempfile
    temp_dir = model_dir = tempfile.mkdtemp()

.. doctest::

    >>> import bob.db.atnt
    >>> from bob.learn.tensorflow.data import dataset_using_generator
    >>> import tensorflow as tf

    >>> db = bob.db.atnt.Database()
    >>> samples = db.objects(groups="world")

    >>> # construct integer labels for each identity in the database
    >>> CLIENT_IDS = (str(f.client_id) for f in samples)
    >>> CLIENT_IDS = list(set(CLIENT_IDS))
    >>> CLIENT_IDS = dict(zip(CLIENT_IDS, range(len(CLIENT_IDS))))

    >>> def reader(sample):
    ...     img = sample.load(db.original_directory, db.original_extension)
    ...     label = CLIENT_IDS[str(sample.client_id)]
    ...     return img, label

    >>> dataset = dataset_using_generator(samples, reader)
    >>> dataset
Create TFRecords from tf.data.Datasets
======================================

Use :any:`bob.learn.tensorflow.data.dataset_to_tfrecord` and
:any:`bob.learn.tensorflow.data.dataset_from_tfrecord` to painlessly convert
**any** ``tf.data.Dataset`` to TFRecords and to create datasets back from those
TFRecords:

.. doctest::

    >>> from bob.learn.tensorflow.data import dataset_to_tfrecord
    >>> from bob.learn.tensorflow.data import dataset_from_tfrecord
    >>> path = f"{temp_dir}/my_dataset"
    >>> dataset_to_tfrecord(dataset, path)
    >>> dataset = dataset_from_tfrecord(path)
    >>> dataset

There is also a script called ``bob tf dataset-to-tfrecord`` that wraps
:any:`bob.learn.tensorflow.data.dataset_to_tfrecord` for easy grid job
submission.
Create models with custom training and evaluation logic
========================================================
Training models for biometrics recognition (and metric learning in general) is
different from the typical classification problems since the labels during
training and testing are different. We found that overriding the ``compile``,
``train_step``, and ``test_step`` methods as explained in
https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit is the
best trade-off between the control of what happens during training and
evaluation and writing boilerplate code.
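For instance, a stripped-down model following that guide could look like the
sketch below. This is not this package's exact implementation; it assumes a loss
that takes labels and embeddings (e.g. a metric-learning loss)::

    import tensorflow as tf

    class EmbeddingModel(tf.keras.Model):  # hypothetical name

        def compile(self, loss_fn, **kwargs):
            super().compile(**kwargs)
            self.loss_fn = loss_fn

        def train_step(self, data):
            images, labels = data
            with tf.GradientTape() as tape:
                embeddings = self(images, training=True)
                loss = self.loss_fn(labels, embeddings)
            gradients = tape.gradient(loss, self.trainable_variables)
            self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
            return {"loss": loss}

        def test_step(self, data):
            images, labels = data
            embeddings = self(images, training=False)
            # evaluation uses the embeddings, not the class labels seen in training
            return {"loss": self.loss_fn(labels, embeddings)}

The complete example at the end of this page applies the same pattern with two
losses (cross-entropy plus center loss), mixed precision, and an embedding-based
validation accuracy.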
Mixed-precision training
========================

When doing mixed-precision training (see
https://www.tensorflow.org/guide/mixed_precision), it is important to scale the
loss before computing the gradients.
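Below is a minimal sketch using the TF 2.3-era experimental API (the same API
used in the complete example at the end of this page); everything around it is a
placeholder::

    import tensorflow as tf
    from tensorflow.keras.mixed_precision import experimental as mixed_precision

    mixed_precision.set_policy(mixed_precision.Policy("mixed_float16"))

    optimizer = mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.RMSprop(learning_rate=0.1), loss_scale="dynamic"
    )

    # Inside a custom ``train_step`` you would then scale/unscale around the
    # gradient computation:
    #
    #     scaled_loss = self.optimizer.get_scaled_loss(loss)
    #     scaled_grads = tape.gradient(scaled_loss, self.trainable_variables)
    #     grads = self.optimizer.get_unscaled_gradients(scaled_grads)
    #     self.optimizer.apply_gradients(zip(grads, self.trainable_variables))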
Multi-GPU and multi-worker training
===================================
It is important that custom metrics and losses do not average their results by
the local batch size; instead, the values should be averaged by the global batch
size (see https://www.tensorflow.org/tutorials/distribute/custom_training). Take
a look at the custom metrics and losses in this package for examples of correct
implementations.
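For example, a cross-entropy loss that is safe under ``tf.distribute`` can be
built with ``Reduction.NONE`` and averaged with ``tf.nn.compute_average_loss``
(a minimal sketch; ``global_batch_size`` is the per-replica batch size times the
number of replicas)::

    import tensorflow as tf

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )

    def compute_loss(labels, logits, global_batch_size):
        # per-example losses, then average by the *global* batch size
        per_example_loss = loss_fn(labels, logits)
        return tf.nn.compute_average_loss(
            per_example_loss, global_batch_size=global_batch_size
        )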
.. testcleanup::

    import shutil
    shutil.rmtree(model_dir, True)
.. _tensorflow: https://www.tensorflow.org/
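The complete example referenced above
(``examples/MSCeleba_centerloss_mixed_precision_multi_worker.py``) follows: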
#!/usr/bin/env python
# coding: utf-8
import json
import os
import pickle
from functools import partial
from multiprocessing import cpu_count
import pkg_resources
import tensorflow as tf
from bob.learn.tensorflow.losses import CenterLoss, CenterLossLayer
from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2
from bob.learn.tensorflow.utils import predict_using_tensors
from tensorflow.keras import layers
from tensorflow.keras.mixed_precision import experimental as mixed_precision
from bob.extension import rc
policy = mixed_precision.Policy("mixed_float16")
mixed_precision.set_policy(policy)
TRAIN_TF_RECORD_PATHS = (
f"{rc['htface']}/databases/tfrecords/msceleba/"
"tfrecord_182x_hand_prunned_44/*.tfrecord"
)
VALIDATION_TF_RECORD_PATHS = (
f"{rc['htface']}/databases/tfrecords/lfw/182x/RGB/*.tfrecord"
)
# there are 2812 samples in the validation set
VALIDATION_SAMPLES = 2812
CHECKPOINT = (
f"{rc['temp']}/models/inception_v2_batchnorm_rgb_msceleba_mixed_precision"
)
AUTOTUNE = tf.data.experimental.AUTOTUNE
TFRECORD_PARALLEL_READ = cpu_count()
N_CLASSES = 87662
DATA_SHAPE = (182, 182, 3) # size of faces
DATA_TYPE = tf.uint8
OUTPUT_SHAPE = (160, 160)
SHUFFLE_BUFFER = int(2e4)
LEARNING_RATE = 0.1
BATCH_SIZE = 90 * 2 # should be a multiple of 8
# We want to run 35 epochs over the training tfrecords. There are 959083 samples
# in the train tfrecords; depending on the batch size, STEPS_PER_EPOCH and
# KERAS_EPOCH_MULTIPLIER should change accordingly.
EPOCHS = 35
# Number of training steps to do before validating the model. This also defines
# an "epoch" for Keras, which is not a real epoch. We want to evaluate every
# 180000 (90 * 2000) samples.
STEPS_PER_EPOCH = 180000 // BATCH_SIZE
# np.ceil(959083 / 180000) = 5.33, rounded up to 6
KERAS_EPOCH_MULTIPLIER = 6
VALIDATION_BATCH_SIZE = 38 # should be a multiple of 8
FEATURES = {
"data": tf.io.FixedLenFeature([], tf.string),
"label": tf.io.FixedLenFeature([], tf.int64),
"key": tf.io.FixedLenFeature([], tf.string),
}
LOSS_WEIGHTS = {"cross_entropy": 1.0, "center_loss": 0.01}
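# Each serialized tf.train.Example stores the image as raw bytes plus an integer
# label and a string key; decode the bytes and reshape them to DATA_SHAPE.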
def decode_tfrecords(x):
features = tf.io.parse_single_example(x, FEATURES)
image = tf.io.decode_raw(features["data"], DATA_TYPE)
image = tf.reshape(image, DATA_SHAPE)
features["data"] = image
return features
def get_preprocessor():
preprocessor = tf.keras.Sequential(
[
# rotate before cropping
# 5 random degree rotation
layers.experimental.preprocessing.RandomRotation(5 / 360),
layers.experimental.preprocessing.RandomCrop(
height=OUTPUT_SHAPE[0], width=OUTPUT_SHAPE[1]
),
layers.experimental.preprocessing.RandomFlip("horizontal"),
# FIXED_STANDARDIZATION from https://github.com/davidsandberg/facenet
# [-0.99609375, 0.99609375]
layers.experimental.preprocessing.Rescaling(
scale=1 / 128, offset=-127.5 / 128
),
]
)
return preprocessor
def preprocess(preprocessor, features, augment=False):
image = features["data"]
label = features["label"]
image = preprocessor(image, training=augment)
return image, label
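# Build the tf.data input pipeline: list the tfrecord files, read them in parallel,
# decode, optionally shuffle/repeat, batch, then run the Keras preprocessing layers
# on whole batches, and prefetch with AUTOTUNE.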
def prepare_dataset(tf_record_paths, batch_size, shuffle=False, augment=False):
ds = tf.data.Dataset.list_files(tf_record_paths, shuffle=shuffle)
ds = tf.data.TFRecordDataset(ds, num_parallel_reads=TFRECORD_PARALLEL_READ)
if shuffle:
# ignore order and read files as soon as they come in
ignore_order = tf.data.Options()
ignore_order.experimental_deterministic = False
ds = ds.with_options(ignore_order)
ds = ds.map(decode_tfrecords).prefetch(buffer_size=AUTOTUNE)
if shuffle:
ds = ds.shuffle(SHUFFLE_BUFFER).repeat(EPOCHS)
preprocessor = get_preprocessor()
ds = ds.batch(batch_size).map(
partial(preprocess, preprocessor, augment=augment), num_parallel_calls=AUTOTUNE,
)
    # Use buffered prefetching on all datasets
return ds.prefetch(buffer_size=AUTOTUNE)
# return ds.apply(tf.data.experimental.prefetch_to_device(
# device, buffer_size=AUTOTUNE))
def accuracy_from_embeddings(labels, prelogits):
labels = tf.reshape(labels, (-1,))
embeddings = tf.nn.l2_normalize(prelogits, 1)
predictions = predict_using_tensors(embeddings, labels)
return tf.math.equal(labels, predictions)
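# Keras model following the "customize what happens in fit" pattern: ``compile``
# receives the two losses and the metrics, ``train_step`` combines cross-entropy
# and center loss (with loss scaling for mixed precision), and ``test_step``
# reports an accuracy computed from the embeddings instead of the logits.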
class CustomModel(tf.keras.Model):
def compile(
self,
cross_entropy,
center_loss,
loss_weights,
train_loss,
train_cross_entropy,
train_center_loss,
test_acc,
global_batch_size,
**kwargs,
):
super().compile(**kwargs)
self.cross_entropy = cross_entropy
self.center_loss = center_loss
self.loss_weights = loss_weights
self.train_loss = train_loss
self.train_cross_entropy = train_cross_entropy
self.train_center_loss = train_center_loss
self.test_acc = test_acc
self.global_batch_size = global_batch_size
def train_step(self, data):
images, labels = data
with tf.GradientTape() as tape:
logits, prelogits = self(images, training=True)
loss_cross = self.cross_entropy(labels, logits)
loss_center = self.center_loss(labels, prelogits)
loss = (
loss_cross * self.loss_weights[self.cross_entropy.name]
+ loss_center * self.loss_weights[self.center_loss.name]
)
unscaled_loss = tf.nn.compute_average_loss(
loss, global_batch_size=self.global_batch_size
)
loss = self.optimizer.get_scaled_loss(unscaled_loss)
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
gradients = self.optimizer.get_unscaled_gradients(gradients)
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
self.train_loss(unscaled_loss)
self.train_cross_entropy(loss_cross)
self.train_center_loss(loss_center)
return {
m.name: m.result()
for m in [self.train_loss, self.train_cross_entropy, self.train_center_loss]
}
def test_step(self, data):
images, labels = data
logits, prelogits = self(images, training=False)
self.test_acc(accuracy_from_embeddings(labels, prelogits))
return {m.name: m.result() for m in [self.test_acc]}
def create_model():
model = InceptionResNetV2(
include_top=True,
classes=N_CLASSES,
bottleneck=True,
input_shape=OUTPUT_SHAPE + (3,),
)
float32_layer = layers.Activation("linear", dtype="float32")
prelogits = model.get_layer("Bottleneck/BatchNorm").output
prelogits = CenterLossLayer(
n_classes=N_CLASSES, n_features=prelogits.shape[-1], name="centers"
)(prelogits)
prelogits = float32_layer(prelogits)
logits = float32_layer(model.get_layer("logits").output)
model = CustomModel(
inputs=model.input, outputs=[logits, prelogits], name=model.name
)
return model
def build_and_compile_model(global_batch_size):
model = create_model()
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, name="cross_entropy", reduction=tf.keras.losses.Reduction.NONE
)
center_loss = CenterLoss(
centers_layer=model.get_layer("centers"),
alpha=0.9,
name="center_loss",
reduction=tf.keras.losses.Reduction.NONE,
)
optimizer = tf.keras.optimizers.RMSprop(
learning_rate=LEARNING_RATE, rho=0.9, momentum=0.9, epsilon=1.0
)
optimizer = mixed_precision.LossScaleOptimizer(optimizer, loss_scale="dynamic")
train_loss = tf.keras.metrics.Mean(name="loss")
train_cross_entropy = tf.keras.metrics.Mean(name="cross_entropy")
train_center_loss = tf.keras.metrics.Mean(name="center_loss")
test_acc = tf.keras.metrics.Mean(name="accuracy")
model.compile(
optimizer=optimizer,
cross_entropy=cross_entropy,
center_loss=center_loss,
loss_weights=LOSS_WEIGHTS,
train_loss=train_loss,
train_cross_entropy=train_cross_entropy,
train_center_loss=train_center_loss,
test_acc=test_acc,
global_batch_size=global_batch_size,
)
return model
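# BackupAndRestore variant that also pickles the provided custom objects (here,
# the callback list) alongside the backup, restores them when training starts,
# and keeps the backup directory when training ends.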
class CustomBackupAndRestore(tf.keras.callbacks.experimental.BackupAndRestore):
    def __init__(self, custom_objects, **kwargs):
        super().__init__(**kwargs)
self.custom_objects = custom_objects
self.custom_objects_path = os.path.join(self.backup_dir, "custom_objects.pkl")
def on_epoch_end(self, epoch, logs=None):
super().on_epoch_end(epoch, logs=logs)
# pickle custom objects
with open(self.custom_objects_path, "wb") as f:
pickle.dump(self.custom_objects, f)
def on_train_begin(self, logs=None):
super().on_train_begin(logs=logs)
if not os.path.exists(self.custom_objects_path):
return
# load custom objects
with open(self.custom_objects_path, "rb") as f:
self.custom_objects = pickle.load(f)
def on_train_end(self, logs=None):
# do not delete backups
pass
def train_and_evaluate(tf_config):
os.environ["TF_CONFIG"] = json.dumps(tf_config)
per_worker_batch_size = BATCH_SIZE
num_workers = len(tf_config["cluster"]["worker"])
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
global_batch_size = per_worker_batch_size * num_workers
val_global_batch_size = VALIDATION_BATCH_SIZE * num_workers
train_ds = prepare_dataset(
TRAIN_TF_RECORD_PATHS, batch_size=global_batch_size, shuffle=True, augment=True
)
val_ds = prepare_dataset(
VALIDATION_TF_RECORD_PATHS,
batch_size=val_global_batch_size,
shuffle=False,
augment=False,
)
with strategy.scope():
model = build_and_compile_model(global_batch_size=global_batch_size)
val_metric_name = "val_accuracy"
def scheduler(epoch, lr):
        # 20 epochs at lr=0.1, 10 epochs at 0.01, and 5 epochs at 0.001
        # The epoch number here is Keras's, which differs from the actual epoch number
epoch = epoch // KERAS_EPOCH_MULTIPLIER
if epoch in range(20):
return 0.1
elif epoch in range(20, 30):
return 0.01
else:
return 0.001
callbacks = [
tf.keras.callbacks.ModelCheckpoint(f"{CHECKPOINT}/latest", verbose=1),
tf.keras.callbacks.ModelCheckpoint(
f"{CHECKPOINT}/best",
monitor=val_metric_name,
save_best_only=True,
mode="max",
verbose=1,
),
tf.keras.callbacks.TensorBoard(
log_dir=f"{CHECKPOINT}/logs", update_freq=15, profile_batch="10,50"
),
tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1),
# tf.keras.callbacks.ReduceLROnPlateau(
# monitor=val_metric_name, factor=0.2, patience=5, min_lr=0.001
# ),
tf.keras.callbacks.TerminateOnNaN(),
]
    callbacks.append(
        CustomBackupAndRestore(
            backup_dir=f"{CHECKPOINT}/backup", custom_objects=callbacks
        )
    )
model.fit(
train_ds,
validation_data=val_ds,
epochs=EPOCHS * KERAS_EPOCH_MULTIPLIER,
steps_per_epoch=STEPS_PER_EPOCH,
validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
callbacks=callbacks,
verbose=2 if os.environ.get("SGE_TASK_ID") else 1,
)
if __name__ == "__main__":
train_and_evaluate({})