diff --git a/bob/ip/binseg/script/train.py b/bob/ip/binseg/script/train.py index 12a939e87a3d300afa3f658db5ba048e1ad66650..aa3cf40e2fa62426b83ce08f7f6bca0c9479040d 100644 --- a/bob/ip/binseg/script/train.py +++ b/bob/ip/binseg/script/train.py @@ -98,7 +98,7 @@ logger = logging.getLogger(__name__) @click.option( "--pretrained-backbone", "-t", - help="URLs of a pre-trained model file that will be used to preset " + help="URL of a pre-trained model file that will be used to preset " "FCN weights (where relevant) before training starts " "(e.g. vgg16, mobilenetv2)", required=True, @@ -108,12 +108,31 @@ logger = logging.getLogger(__name__) "--batch-size", "-b", help="Number of samples in every batch (this parameter affects " - "memory requirements for the network)", + "memory requirements for the network). If the number of samples in " + "the batch is larger than the total number of samples available for " + "training, this value is truncated. If this number is smaller, then " + "batches of the specified size are created and fed to the network " + "until there are no more new samples to feed (epoch is finished). " + "If the total number of training samples is not a multiple of the " + "batch-size, the last batch will be smaller than the first, unless " + "--drop-incomplete--batch is set, in which case this batch is not used.", required=True, show_default=True, default=2, cls=ResourceOption, ) +@click.option( + "--drop-incomplete-batch/--no-drop-incomplete-batch", + "-D", + help="If set, then may drop the last batch in an epoch, in case it is " + "incomplete. If you set this option, you should also consider " + "increasing the total number of epochs of training, as the total number " + "of training steps may be reduced", + required=True, + show_default=True, + default=False, + cls=ResourceOption, +) @click.option( "--epochs", "-e", @@ -180,6 +199,7 @@ def train( epochs, pretrained_backbone, batch_size, + drop_incomplete_batch, criterion, dataset, checkpoint_period, @@ -193,10 +213,10 @@ def train( """Trains an FCN to perform binary segmentation using a supervised approach Training is performed for a configurable number of epochs, and generates at - least a final model (.pth file). It may also generate a number of - intermediate checkpoints. Checkpoints are model files (.pth files) that - are stored during the training and useful to resume the procedure in case - it stops abruptly. + least a final_model.pth. It may also generate a number of intermediate + checkpoints. Checkpoints are model files (.pth files) that are stored + during the training and useful to resume the procedure in case it stops + abruptly. """ if not os.path.exists(output_path): @@ -208,6 +228,7 @@ def train( dataset=dataset, batch_size=batch_size, shuffle=True, + drop_last=drop_incomplete_batch, pin_memory=torch.cuda.is_available(), )