From 214fad7aaaf9ff246cde75c59026688be27a7814 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 17 Apr 2019 16:17:26 +0200
Subject: [PATCH] Add scripts for training keras models using keras API

---
 bob/learn/tensorflow/script/fit.py   | 107 +++++++++++++++++++++++++++
 bob/learn/tensorflow/script/keras.py |  15 ++++
 2 files changed, 122 insertions(+)
 create mode 100644 bob/learn/tensorflow/script/fit.py
 create mode 100644 bob/learn/tensorflow/script/keras.py

diff --git a/bob/learn/tensorflow/script/fit.py b/bob/learn/tensorflow/script/fit.py
new file mode 100644
index 00000000..f19776c5
--- /dev/null
+++ b/bob/learn/tensorflow/script/fit.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+"""Trains networks using Keras Models.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import click
+import json
+import logging
+import os
+import tensorflow as tf
+from bob.extension.scripts.click_helper import (
+    verbosity_option, ConfigCommand, ResourceOption, log_parameters)
+
+logger = logging.getLogger(__name__)
+
+
+@click.command(
+    entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
+@click.option(
+    '--model',
+    '-m',
+    required=True,
+    cls=ResourceOption,
+    entry_point_group='bob.learn.tensorflow.model',
+    help='The keras model that will be trained.')
+@click.option(
+    '--train-input-fn',
+    '-i',
+    required=True,
+    cls=ResourceOption,
+    entry_point_group='bob.learn.tensorflow.input_fn',
+    help='A function that will return the training data as a tf.data.Dataset '
+    'or tf.data.Iterator. This will be given as `x` to '
+    'tf.keras.Model.fit.')
+@click.option(
+    '--epochs',
+    '-e',
+    default=1,
+    type=click.types.INT,
+    cls=ResourceOption,
+    help='Number of epochs to train model. See '
+    'tf.keras.Model.fit.')
+@click.option(
+    '--callbacks',
+    cls=ResourceOption,
+    multiple=True,
+    entry_point_group='bob.learn.tensorflow.callback',
+    help='List of tf.keras.callbacks. Used for callbacks '
+    'inside the training loop.')
+@click.option(
+    '--eval-input-fn',
+    '-i',
+    cls=ResourceOption,
+    entry_point_group='bob.learn.tensorflow.input_fn',
+    help='A function that will return the validation data as a tf.data.Dataset'
+    ' or tf.data.Iterator. This will be given as `validation_data` to '
+    'tf.keras.Model.fit.')
+@click.option(
+    '--class-weight',
+    '-c',
+    cls=ResourceOption,
+    help='See tf.keras.Model.fit.')
+@click.option(
+    '--initial-epoch',
+    default=0,
+    type=click.types.INT,
+    cls=ResourceOption,
+    help='See tf.keras.Model.fit.')
+@click.option(
+    '--steps-per-epoch',
+    type=click.types.INT,
+    cls=ResourceOption,
+    help='See tf.keras.Model.fit.')
+@click.option(
+    '--validation-steps',
+    type=click.types.INT,
+    cls=ResourceOption,
+    help='See tf.keras.Model.fit.')
+@verbosity_option(cls=ResourceOption)
+def fit(model, train_input_fn, epochs, verbose, callbacks, eval_input_fn,
+        class_weight, initial_epoch, steps_per_epoch, validation_steps,
+        **kwargs):
+    """Trains networks using Keras models."""
+    log_parameters(logger)
+
+    # Train
+    save_callback = [c for c in callbacks if isinstance(c, tf.keras.callbacks.ModelCheckpoint)]
+    model_dir = None
+    if save_callback:
+        model_dir = save_callback[0].filepath
+        logger.info("Training a model in %s", model_dir)
+    history = model.fit(
+        x=train_input_fn(),
+        epochs=epochs,
+        verbose=max(verbose, 2),
+        callbacks=list(callbacks) if callbacks else None,
+        validation_data=None if eval_input_fn is None else eval_input_fn(),
+        class_weight=class_weight,
+        initial_epoch=initial_epoch,
+        steps_per_epoch=steps_per_epoch,
+        validation_steps=validation_steps,
+    )
+    click.echo(history.history)
+    if model_dir is not None:
+        with open(os.path.join(model_dir, 'keras_fit_history.json'), 'w') as f:
+            json.dump(history.history, f)
diff --git a/bob/learn/tensorflow/script/keras.py b/bob/learn/tensorflow/script/keras.py
new file mode 100644
index 00000000..9f6ee9c2
--- /dev/null
+++ b/bob/learn/tensorflow/script/keras.py
@@ -0,0 +1,15 @@
+"""The main entry for bob keras (click-based) scripts.
+"""
+import click
+import pkg_resources
+from click_plugins import with_plugins
+from bob.extension.scripts.click_helper import AliasedGroup
+from .utils import eager_execution_option
+
+
+@with_plugins(pkg_resources.iter_entry_points('bob.learn.tensorflow.keras_cli'))
+@click.group(cls=AliasedGroup)
+@eager_execution_option()
+def keras():
+    """Keras-related commands."""
+    pass
-- 
GitLab