From 214fad7aaaf9ff246cde75c59026688be27a7814 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Wed, 17 Apr 2019 16:17:26 +0200 Subject: [PATCH] Add scripts for training keras models using keras API --- bob/learn/tensorflow/script/fit.py | 107 +++++++++++++++++++++++++++ bob/learn/tensorflow/script/keras.py | 15 ++++ 2 files changed, 122 insertions(+) create mode 100644 bob/learn/tensorflow/script/fit.py create mode 100644 bob/learn/tensorflow/script/keras.py diff --git a/bob/learn/tensorflow/script/fit.py b/bob/learn/tensorflow/script/fit.py new file mode 100644 index 00000000..f19776c5 --- /dev/null +++ b/bob/learn/tensorflow/script/fit.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +"""Trains networks using Keras Models. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import click +import json +import logging +import os +import tensorflow as tf +from bob.extension.scripts.click_helper import ( + verbosity_option, ConfigCommand, ResourceOption, log_parameters) + +logger = logging.getLogger(__name__) + + +@click.command( + entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand) +@click.option( + '--model', + '-m', + required=True, + cls=ResourceOption, + entry_point_group='bob.learn.tensorflow.model', + help='The keras model that will be trained.') +@click.option( + '--train-input-fn', + '-i', + required=True, + cls=ResourceOption, + entry_point_group='bob.learn.tensorflow.input_fn', + help='A function that will return the training data as a tf.data.Dataset ' + 'or tf.data.Iterator. This will be given as `x` to ' + 'tf.keras.Model.fit.') +@click.option( + '--epochs', + '-e', + default=1, + type=click.types.INT, + cls=ResourceOption, + help='Number of epochs to train model. See ' + 'tf.keras.Model.fit.') +@click.option( + '--callbacks', + cls=ResourceOption, + multiple=True, + entry_point_group='bob.learn.tensorflow.callback', + help='List of tf.keras.callbacks. Used for callbacks ' + 'inside the training loop.') +@click.option( + '--eval-input-fn', + '-i', + cls=ResourceOption, + entry_point_group='bob.learn.tensorflow.input_fn', + help='A function that will return the validation data as a tf.data.Dataset' + ' or tf.data.Iterator. This will be given as `validation_data` to ' + 'tf.keras.Model.fit.') +@click.option( + '--class-weight', + '-c', + cls=ResourceOption, + help='See tf.keras.Model.fit.') +@click.option( + '--initial-epoch', + default=0, + type=click.types.INT, + cls=ResourceOption, + help='See tf.keras.Model.fit.') +@click.option( + '--steps-per-epoch', + type=click.types.INT, + cls=ResourceOption, + help='See tf.keras.Model.fit.') +@click.option( + '--validation-steps', + type=click.types.INT, + cls=ResourceOption, + help='See tf.keras.Model.fit.') +@verbosity_option(cls=ResourceOption) +def fit(model, train_input_fn, epochs, verbose, callbacks, eval_input_fn, + class_weight, initial_epoch, steps_per_epoch, validation_steps, + **kwargs): + """Trains networks using Keras models.""" + log_parameters(logger) + + # Train + save_callback = [c for c in callbacks if isinstance(c, tf.keras.callbacks.ModelCheckpoint)] + model_dir = None + if save_callback: + model_dir = save_callback[0].filepath + logger.info("Training a model in %s", model_dir) + history = model.fit( + x=train_input_fn(), + epochs=epochs, + verbose=max(verbose, 2), + callbacks=list(callbacks) if callbacks else None, + validation_data=None if eval_input_fn is None else eval_input_fn(), + class_weight=class_weight, + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps, + ) + click.echo(history.history) + if model_dir is not None: + with open(os.path.join(model_dir, 'keras_fit_history.json'), 'w') as f: + json.dump(history.history, f) diff --git a/bob/learn/tensorflow/script/keras.py b/bob/learn/tensorflow/script/keras.py new file mode 100644 index 00000000..9f6ee9c2 --- /dev/null +++ b/bob/learn/tensorflow/script/keras.py @@ -0,0 +1,15 @@ +"""The main entry for bob keras (click-based) scripts. +""" +import click +import pkg_resources +from click_plugins import with_plugins +from bob.extension.scripts.click_helper import AliasedGroup +from .utils import eager_execution_option + + +@with_plugins(pkg_resources.iter_entry_points('bob.learn.tensorflow.keras_cli')) +@click.group(cls=AliasedGroup) +@eager_execution_option() +def keras(): + """Keras-related commands.""" + pass -- GitLab