From 7381c41601271d58889bd03f29c72dbb13fc1385 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Tue, 7 Apr 2020 11:38:46 +0200
Subject: [PATCH] [various] Re-organize custom dataset classes; Improve
 documentation; Re-define prediction

---
 bob/ip/binseg/configs/datasets/csv.py         |  85 ++++++++++
 bob/ip/binseg/configs/datasets/folder.py      |  51 ++++++
 bob/ip/binseg/configs/datasets/imagefolder.py |  23 ---
 .../configs/datasets/imagefolderinference.py  |  14 --
 .../configs/datasets/imagefoldertest.py       |  14 --
 bob/ip/binseg/data/csvdataset.py              | 157 ++++++++++++++++++
 bob/ip/binseg/data/folderdataset.py           |  83 +++++++++
 bob/ip/binseg/data/imagefolder.py             |  82 ---------
 bob/ip/binseg/data/imagefolderinference.py    |  82 ---------
 bob/ip/binseg/engine/inferencer.py            |  54 ++----
 bob/ip/binseg/engine/predicter.py             |  90 ----------
 bob/ip/binseg/engine/predictor.py             | 112 +++++++++++++
 bob/ip/binseg/script/binseg.py                |  51 +-----
 bob/ip/binseg/script/evaluate.py              |  13 +-
 bob/ip/binseg/script/predict.py               | 109 ++++++++++++
 conda/meta.yaml                               |   1 +
 doc/api.rst                                   |  12 +-
 doc/cli.rst                                   |  16 +-
 doc/conf.py                                   |  23 ---
 doc/datasets.rst                              |  39 -----
 doc/evaluation.rst                            |  69 ++++++--
 doc/training.rst                              |  18 ++
 requirements.txt                              |   1 +
 setup.py                                      |   5 +-
 24 files changed, 721 insertions(+), 483 deletions(-)
 create mode 100644 bob/ip/binseg/configs/datasets/csv.py
 create mode 100644 bob/ip/binseg/configs/datasets/folder.py
 delete mode 100644 bob/ip/binseg/configs/datasets/imagefolder.py
 delete mode 100644 bob/ip/binseg/configs/datasets/imagefolderinference.py
 delete mode 100644 bob/ip/binseg/configs/datasets/imagefoldertest.py
 create mode 100644 bob/ip/binseg/data/csvdataset.py
 create mode 100644 bob/ip/binseg/data/folderdataset.py
 delete mode 100644 bob/ip/binseg/data/imagefolder.py
 delete mode 100644 bob/ip/binseg/data/imagefolderinference.py
 delete mode 100644 bob/ip/binseg/engine/predicter.py
 create mode 100644 bob/ip/binseg/engine/predictor.py
 create mode 100644 bob/ip/binseg/script/predict.py

diff --git a/bob/ip/binseg/configs/datasets/csv.py b/bob/ip/binseg/configs/datasets/csv.py
new file mode 100644
index 00000000..410008f8
--- /dev/null
+++ b/bob/ip/binseg/configs/datasets/csv.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""Example CSV-based filelist dataset
+
+In case you have your own dataset that is organized on your filesystem, this
+configuration shows an example setup so you can feed such files and
+ground-truth data to train one of the available network models or to evaluate
+it.
+
+You must write a CSV-based file (e.g. using commas as separators) that
+describes the image and ground-truth locations for each image pair in your
+dataset.  Relative paths are interpreted with respect to the location of the
+CSV file itself by default; alternatively, pass the ``root_path`` parameter to
+the :py:class:`bob.ip.binseg.data.csvdataset.CSVDataset` object constructor.
+So, for example, if you have a structure like this:
+
+.. code-block:: text
+
+   ├── images
+   │   ├── image_1.png
+   │   ├── ...
+   │   └── image_n.png
+   └── ground-truth
+       ├── gt_1.png
+       ├── ...
+       └── gt_n.png
+
+Then create a file at the same level as ``images`` and ``ground-truth`` with
+the following contents:
+
+.. code-block:: text
+
+   images/image_1.png,ground-truth/gt_1.png
+   ...,...
+   images/image_n.png,ground-truth/gt_n.png
+
+To create a dataset without ground-truth (e.g., for prediction purposes),
+simply omit the second column of the CSV file.
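+
+For example, a prediction-only version of the filelist above keeps just the
+first column (the file names below are the same hypothetical ones):
+
+.. code-block:: text
+
+   images/image_1.png
+   ...
+   images/image_n.png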
+
+Use the path leading to the CSV file to replace ``<path.csv>`` in the example
+code for this configuration, which you must copy locally to make changes:
+
+.. code-block:: sh
+
+   $ bob binseg config copy csv-dataset-example mydataset.py
+   # edit mydataset.py as explained here
+
+Fine-tune the transformations for your particular purpose:
+
+    1. If you are training a new model, you may add random image
+       transformations
+    2. If you are running prediction, you should skip random image
+       transformations
+
+Keep in mind that specific models require that you feed images respecting
+certain restrictions (input dimensions, image centering, etc.).  Check the
+configuration that was used to train models and try to match it as well as
+possible.
+
+See:
+
+* :py:class:`bob.ip.binseg.data.csvdataset.CSVDataset` for operational details.
+* :py:class:`bob.ip.binseg.data.folderdataset.FolderDataset` for an alternative
+  implementation of an easier-to-generate **prediction** dataset.
+
+"""
+
+from bob.ip.binseg.data.transforms import *
+from bob.ip.binseg.data.csvdataset import CSVDataset
+
+# add your transforms below - these are just examples
+# keep the ``ToTensor()`` transform at the end
+transforms = Compose(
+    [
+        #CenterCrop((544, 544)),
+        #RandomHFlip(),
+        #RandomVFlip(),
+        #RandomRotation(),
+        #ColorJitter(),
+        ToTensor(),
+    ]
+)
+
+#dataset = CSVDataset("<path.csv>", check_available=False, transform=transforms)
diff --git a/bob/ip/binseg/configs/datasets/folder.py b/bob/ip/binseg/configs/datasets/folder.py
new file mode 100644
index 00000000..d34c2797
--- /dev/null
+++ b/bob/ip/binseg/configs/datasets/folder.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""Example self-scanning folder-based dataset
+
+In case you have data that is organized on your filesystem, this configuration
+shows an example setup so you can feed such files **without** ground-truth to
+predict vessel probabilities using one of our trained models.  There can be
+any number of images within the root folder of your dataset, with any kind of
+subfolder arrangements.  For example:
+
+.. code-block:: text
+
+   ├── image_1.png
+   ├── subdir1
+   │   ├── image_subdir_1.jpg
+   │   ├── ...
+   │   └── image_subdir_k.jpg
+   ├── ...
+   └── image_n.png
+
+Use the path leading to the root of your dataset to replace ``<path>`` in the
+example code for this configuration, which you must copy locally to make
+changes:
+
+.. code-block:: sh
+
+   $ bob binseg config copy folder-dataset-example mydataset.py
+   # edit mydataset.py as explained here
+
+Fine-tune the transformations for your particular purpose.
+
+Keep in mind that specific models require that you feed images respecting
+certain restrictions (input dimensions, image centering, etc.).  Check the
+configuration that was used to train models and try to match it as well as
+possible.
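+
+For example, assuming the (hypothetical) tree above sits under
+``/data/my-images`` and only PNG files should be fed to the model, the final,
+uncommented line of this configuration could read:
+
+.. code-block:: python
+
+   dataset = FolderDataset("/data/my-images", glob="*.png", transform=transforms)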
+""" + +from bob.ip.binseg.data.transforms import * +from bob.ip.binseg.data.folderdataset import FolderDataset + +# add your transforms below - these are just examples +# keep the ``ToTensor()`` transform at the end +transforms = Compose( + [ + #CenterCrop((544, 544)), + ToTensor(), + ] +) + +#dataset = FolderDataset("<path.csv>", glob="*.*", transform=transforms) diff --git a/bob/ip/binseg/configs/datasets/imagefolder.py b/bob/ip/binseg/configs/datasets/imagefolder.py deleted file mode 100644 index d73de0c4..00000000 --- a/bob/ip/binseg/configs/datasets/imagefolder.py +++ /dev/null @@ -1,23 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from bob.ip.binseg.data.transforms import * -from bob.ip.binseg.data.imagefolder import ImageFolder - -#### Config #### - -# add your transforms below -transforms = Compose( - [ - CenterCrop((544, 544)), - RandomHFlip(), - RandomVFlip(), - RandomRotation(), - ColorJitter(), - ToTensor(), - ] -) - -# PyTorch dataset -path = "/path/to/dataset" -dataset = ImageFolder(path, transform=transforms) diff --git a/bob/ip/binseg/configs/datasets/imagefolderinference.py b/bob/ip/binseg/configs/datasets/imagefolderinference.py deleted file mode 100644 index 869f7e51..00000000 --- a/bob/ip/binseg/configs/datasets/imagefolderinference.py +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from bob.ip.binseg.data.transforms import * -from bob.ip.binseg.data.imagefolderinference import ImageFolderInference - -#### Config #### - -# add your transforms below -transforms = Compose([ToRGB(), CenterCrop((544, 544)), ToTensor()]) - -# PyTorch dataset -path = "/path/to/folder/containing/images" -dataset = ImageFolderInference(path, transform=transforms) diff --git a/bob/ip/binseg/configs/datasets/imagefoldertest.py b/bob/ip/binseg/configs/datasets/imagefoldertest.py deleted file mode 100644 index 474b0384..00000000 --- a/bob/ip/binseg/configs/datasets/imagefoldertest.py +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from bob.ip.binseg.data.transforms import * -from bob.ip.binseg.data.imagefolder import ImageFolder - -#### Config #### - -# add your transforms below -transforms = Compose([CenterCrop((544, 544)), ToTensor()]) - -# PyTorch dataset -path = "/path/to/testdataset" -dataset = ImageFolder(path, transform=transforms) diff --git a/bob/ip/binseg/data/csvdataset.py b/bob/ip/binseg/data/csvdataset.py new file mode 100644 index 00000000..047bba6d --- /dev/null +++ b/bob/ip/binseg/data/csvdataset.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python +# coding=utf-8 + +import os +import csv + +from PIL import Image + +from torch.utils.data import Dataset +import torch +import torchvision.transforms.functional as VF + +import bob.io.base + + +class CSVDataset(Dataset): + """ + Generic filelist dataset + + To create a new dataset, you only need to provide a CSV formatted filelist + using any separator (e.g. comma, space, semi-colon) including, in the first + column, a path pointing to the input image, and in the second column, a + path pointing to the ground truth. Relative paths are interpreted with + respect to the location where the CSV file is or to an optional + ``root_path`` parameter, that must be also provided. + + There are no requirements concerning image or ground-truth homogenity. + Anything that can be loaded by our image and data loaders is OK. Use + a non-white character as separator. Here is a far too complicated example: + + .. 
code-block:: text + + /path/to/image1.jpg,/path/to/ground-truth1.png + /possibly/another/path/to/image 2.PNG,/path/to/that/ground-truth.JPG + relative/path/image3.gif,relative/path/gt3.gif + + .. important:: + + Images are converted to RGB after readout via PIL. Ground-truth data is + loaded using the same technique, but converted to mode ``1`` instead of + ``RGB``. If ground-truth data is encoded as an HDF5 file, we use + instead :py:func:`bob.io.base.load`, and then converted it to 32-bit + float data. + + To generate a dataset without ground-truth (e.g. for prediction tasks), + then omit the second column. + + + Parameters + ---------- + path : str + Full path to the file containing the dataset description, in CSV + format as described above + + root_path : :py:class:`str`, Optional + Path to a common filesystem root where files with relative paths should + be sitting. If not set, then we use the absolute path leading to the + CSV file as ``root_path`` + + check_available : :py:class:`bool`, Optional + If set to ``True``, then checks if files in the file list are + available. Otherwise does not. + + transform : :py:class:`.transforms.Compose`, Optional + a composition of transformations to be applied to **both** image and + ground-truth data. Notice that image changing transformations such as + :py:class:`.transforms.ColorJitter` are only applied to the image and + **not** to ground-truth. + + """ + + def __init__(self, path, root_path=None, check_available=True, transform=None): + + self.root_path = root_path or os.path.dirname(path) + self.transform = transform + + def _make_abs_path(root, s): + retval = [] + for p in s: + if not os.path.isabs(p): + retval.append(os.path.join(root, p)) + return retval + + with open(path, newline='') as f: + reader = csv.reader(f) + self.data = [_make_abs_path(self.root_path, k) for k in reader] + + # check if all files are readable, warn otherwise + if check_available: + errors = 0 + for s in self.data: + for p in s: + if not os.path.exists(p): + errors += 1 + logger.error(f"Cannot find {p}") + assert errors == 0, f"There {errors} files which cannot be " \ + f"found on your filelist ({path}) dataset" + + # check all data entries have the same size + assert all(len(k) == len(self.data[0]) for k in self.data), \ + f"There is an inconsistence on your dataset - not all " \ + f"entries have length=={len(self.data[0])}" + + + def has_ground_truth(self): + """Tells if this dataset has ground-truth or not""" + return len(self.data[0]) > 1 + + + def __len__(self): + """ + + Returns + ------- + + length : int + size of the dataset + """ + + return len(self.data) + + def __getitem__(self, index): + """ + + Parameters + ---------- + index : int + + Returns + ------- + sample : list + ``[name, img, gt]`` or ``[name, img]`` depending on whether this + dataset has or not ground-truth. 
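+
+        For example, with ground-truth available and no transforms set, a
+        sample unpacks like this (the ``dataset`` instance is hypothetical)::
+
+            name, image, gt = dataset[0]  # path, PIL image, PIL ground-truth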
+        """
+
+        sample_paths = self.data[index]
+
+        img_path = sample_paths[0]
+        gt_path = sample_paths[1] if len(sample_paths) > 1 else None
+
+        # images are converted to RGB mode automatically
+        sample = [Image.open(img_path).convert(mode="RGB")]
+
+        if gt_path is not None:
+            if gt_path.endswith(".hdf5"):
+                gt = bob.io.base.load(str(gt_path)).astype("float32")
+                # a bit hackish, but will get what we need
+                gt = VF.to_pil_image(torch.from_numpy(gt))
+            else:
+                gt = Image.open(gt_path)
+                gt = gt.convert(mode="1", dither=None)
+            sample = sample + [gt]
+
+        if self.transform:
+            sample = self.transform(*sample)
+
+        return [img_path] + sample
diff --git a/bob/ip/binseg/data/folderdataset.py b/bob/ip/binseg/data/folderdataset.py
new file mode 100644
index 00000000..19366884
--- /dev/null
+++ b/bob/ip/binseg/data/folderdataset.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+from pathlib import Path
+
+from PIL import Image
+
+from torch.utils.data import Dataset
+
+
+def _find_files(data_path, glob):
+    """
+    Recursively retrieves file lists from a given path, matching a given glob
+
+    This function will use :py:meth:`pathlib.Path.rglob`, together with the
+    provided glob pattern, to search for files matching the desired pattern.
+    """
+
+    data_path = Path(data_path)
+    return sorted(list(data_path.rglob(glob)))
+
+
+class FolderDataset(Dataset):
+    """
+    Generic image folder containing images for prediction
+
+    .. important::
+
+        This implementation, contrary to its sister
+        :py:class:`.csvdataset.CSVDataset`, does not *automatically* convert
+        the input image to RGB, before passing it to the transforms, so it is
+        possible to accommodate a wider range of input types (e.g. 16-bit PNG
+        images).
+
+    Parameters
+    ----------
+
+    path : str
+        full path to root of dataset
+
+    glob : str
+        glob that can be used to filter-down files to be loaded on the
+        provided path
+
+    transform : :py:class:`.transforms.Compose`, Optional
+        a composition of transformations to be applied to the input image
+        (this dataset carries no ground-truth data that could be transformed)
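+
+    For example, a minimal sketch (path and glob here are hypothetical):
+
+    .. code-block:: python
+
+       dataset = FolderDataset("/data/my-images", glob="*.png")
+       name, image = dataset[0]  # relative path and (untransformed) PIL image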
+ + """ + + def __init__(self, path, glob="*", transform=None): + self.transform = transform + self.path = path + self.data = _find_files(path, glob) + + def __len__(self): + """ + Returns + ------- + int + size of the dataset + """ + + return len(self.data) + + def __getitem__(self, index): + """ + Parameters + ---------- + index : int + + Returns + ------- + sample : list + [name, img] + """ + + sample = [Image.open(self.data[index])] + if self.transform: + sample = self.transform(*sample) + return [self.data[index].relative_to(self.path).as_posix()] + sample diff --git a/bob/ip/binseg/data/imagefolder.py b/bob/ip/binseg/data/imagefolder.py deleted file mode 100644 index 9794f5c1..00000000 --- a/bob/ip/binseg/data/imagefolder.py +++ /dev/null @@ -1,82 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -from torch.utils.data import Dataset -from pathlib import Path -import numpy as np -from PIL import Image -import torch -import torchvision.transforms.functional as VF -import bob.io.base - - -def get_file_lists(data_path): - data_path = Path(data_path) - - image_path = data_path.joinpath("images") - image_file_names = np.array(sorted(list(image_path.glob("*")))) - - gt_path = data_path.joinpath("gt") - gt_file_names = np.array(sorted(list(gt_path.glob("*")))) - return image_file_names, gt_file_names - - -class ImageFolder(Dataset): - """ - Generic ImageFolder dataset, that contains two folders: - - * ``images`` (vessel images) - * ``gt`` (ground-truth labels) - - - Parameters - ---------- - path : str - full path to root of dataset - - """ - - def __init__(self, path, transform=None): - self.transform = transform - self.img_file_list, self.gt_file_list = get_file_lists(path) - - def __len__(self): - """ - Returns - ------- - int - size of the dataset - """ - return len(self.img_file_list) - - def __getitem__(self, index): - """ - Parameters - ---------- - index : int - - Returns - ------- - list - dataitem [img_name, img, gt, mask] - """ - img_path = self.img_file_list[index] - img_name = img_path.name - img = Image.open(img_path).convert(mode="RGB") - - gt_path = self.gt_file_list[index] - if gt_path.suffix == ".hdf5": - gt = bob.io.base.load(str(gt_path)).astype("float32") - # not elegant but since transforms require PIL images we do this hacky workaround here - gt = torch.from_numpy(gt) - gt = VF.to_pil_image(gt).convert(mode="1", dither=None) - else: - gt = Image.open(gt_path).convert(mode="1", dither=None) - - sample = [img, gt] - - if self.transform: - sample = self.transform(*sample) - - sample.insert(0, img_name) - - return sample diff --git a/bob/ip/binseg/data/imagefolderinference.py b/bob/ip/binseg/data/imagefolderinference.py deleted file mode 100644 index 6f21d2e9..00000000 --- a/bob/ip/binseg/data/imagefolderinference.py +++ /dev/null @@ -1,82 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -from torch.utils.data import Dataset -from pathlib import Path -import numpy as np -from PIL import Image - - -def get_file_lists(data_path, glob): - """ - Recursively retrieves file lists from a given path, matching a given glob - - This function will use :py:meth:`pathlib.Path.rglob`, together with the - provided glob pattern to search for anything the desired filename. 
- """ - - data_path = Path(data_path) - image_file_names = np.array(sorted(list(data_path.rglob(glob)))) - return image_file_names - - -class ImageFolderInference(Dataset): - """ - Generic ImageFolder containing images for inference - - Notice that this implementation, contrary to its sister - :py:class:`.ImageFolder`, does not *automatically* - convert the input image to RGB, before passing it to the transforms, so it - is possible to accomodate a wider range of input types (e.g. 16-bit PNG - images). - - Parameters - ---------- - path : str - full path to root of dataset - - glob : str - glob that can be used to filter-down files to be loaded on the provided - path - - transform : list - List of transformations to apply to every input sample - - """ - - def __init__(self, path, glob="*", transform=None): - self.transform = transform - self.path = path - self.img_file_list = get_file_lists(path, glob) - - def __len__(self): - """ - Returns - ------- - int - size of the dataset - """ - return len(self.img_file_list) - - def __getitem__(self, index): - """ - Parameters - ---------- - index : int - - Returns - ------- - list - dataitem [img_name, img] - """ - img_path = self.img_file_list[index] - img_name = img_path.relative_to(self.path).as_posix() - img = Image.open(img_path) - - sample = [img] - - if self.transform: - sample = self.transform(*sample) - - sample.insert(0, img_name) - - return sample diff --git a/bob/ip/binseg/engine/inferencer.py b/bob/ip/binseg/engine/inferencer.py index 4eae75ce..91191ab0 100644 --- a/bob/ip/binseg/engine/inferencer.py +++ b/bob/ip/binseg/engine/inferencer.py @@ -20,28 +20,33 @@ import logging logger = logging.getLogger(__name__) -def batch_metrics(predictions, ground_truths, names, output_folder, logger): +def batch_metrics(predictions, ground_truths, names, output_folder): """ Calculates metrics on the batch and saves it to disc + Parameters ---------- + predictions : :py:class:`torch.Tensor` tensor with pixel-wise probabilities + ground_truths : :py:class:`torch.Tensor` tensor with binary ground-truth + names : list list of file names + output_folder : str output path - logger : :py:class:`logging.Logger` - python logger + Returns ------- - list - list containing batch metrics: ``[name, threshold, precision, recall, specificity, accuracy, jaccard, f1_score]`` + metrics : tuple + A tuple containing batch metrics: ``(name, threshold, precision, recall, specificity, accuracy, jaccard, f1_score)`` """ + step_size = 0.01 batch_metrics = [] @@ -101,17 +106,24 @@ def save_probability_images(predictions, names, output_folder, logger): """ Saves probability maps as image in the same format as the test image + Parameters ---------- + predictions : :py:class:`torch.Tensor` tensor with pixel-wise probabilities + names : list list of file names + output_folder : str output path + logger : :py:class:`logging.Logger` python logger + """ + images_subfolder = os.path.join(output_folder, "images") for j in range(predictions.size()[0]): img = VF.to_pil_image(predictions.cpu().data[j]) @@ -124,39 +136,9 @@ def save_probability_images(predictions, names, output_folder, logger): img.save(fullpath) -def save_hdf(predictions, names, output_folder, logger): - """ - Saves probability maps as image in the same format as the test image - - Parameters - ---------- - predictions : :py:class:`torch.Tensor` - tensor with pixel-wise probabilities - names : list - list of file names - output_folder : str - output path - logger : :py:class:`logging.Logger` - python logger - """ - 
hdf5_subfolder = os.path.join(output_folder, "hdf5")
-    if not os.path.exists(hdf5_subfolder):
-        os.makedirs(hdf5_subfolder)
-    for j in range(predictions.size()[0]):
-        img = predictions.cpu().data[j].squeeze(0).numpy()
-        filename = "{}.hdf5".format(names[j].split(".")[0])
-        fullpath = os.path.join(hdf5_subfolder, filename)
-        logger.info("saving {}".format(filename))
-        fulldir = os.path.dirname(fullpath)
-        if not os.path.exists(fulldir):
-            os.makedirs(fulldir)
-        bob.io.base.save(img, fullpath)
-
-
 def do_inference(model, data_loader, device, output_folder=None):
     """
-    Run inference and calculate metrics
+    Runs inference and calculates metrics
 
     Parameters
     ---------
diff --git a/bob/ip/binseg/engine/predicter.py b/bob/ip/binseg/engine/predicter.py
deleted file mode 100644
index 3aabffcc..00000000
--- a/bob/ip/binseg/engine/predicter.py
+++ /dev/null
@@ -1,90 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-import os
-import time
-import datetime
-import numpy as np
-import torch
-from tqdm import tqdm
-
-from bob.ip.binseg.engine.inferencer import save_probability_images
-from bob.ip.binseg.engine.inferencer import save_hdf
-
-import logging
-logger = logging.getLogger(__name__)
-
-
-def do_predict(model, data_loader, device, output_folder=None):
-
-    """
-    Run inference and calculate metrics
-
-    Parameters
-    ---------
-    model : :py:class:`torch.nn.Module`
-        neural network model (e.g. DRIU, HED, UNet)
-    data_loader : py:class:`torch.torch.utils.data.DataLoader`
-    device : str
-        device to use ``'cpu'`` or ``'cuda'``
-    output_folder : str
-    """
-    logger.info("Start evaluation")
-    logger.info("Output folder: {}, Device: {}".format(output_folder, device))
-    results_subfolder = os.path.join(output_folder, "results")
-    os.makedirs(results_subfolder, exist_ok=True)
-
-    model.eval().to(device)
-    # Sigmoid for probabilities
-    sigmoid = torch.nn.Sigmoid()
-
-    # Setup timers
-    start_total_time = time.time()
-    times = []
-
-    for samples in tqdm(data_loader):
-        names = samples[0]
-        images = samples[1].to(device)
-        with torch.no_grad():
-            start_time = time.perf_counter()
-
-            outputs = model(images)
-
-            # necessary check for hed architecture that uses several outputs
-            # for loss calculation instead of just the last concatfuse block
-            if isinstance(outputs, list):
-                outputs = outputs[-1]
-
-            probabilities = sigmoid(outputs)
-
-            batch_time = time.perf_counter() - start_time
-            times.append(batch_time)
-            logger.info("Batch time: {:.5f} s".format(batch_time))
-
-            # Create probability images
-            save_probability_images(probabilities, names, output_folder, logger)
-            # Save hdf5
-            save_hdf(probabilities, names, output_folder, logger)
-
-    # Report times
-    total_inference_time = str(datetime.timedelta(seconds=int(sum(times))))
-    average_batch_inference_time = np.mean(times)
-    total_evalution_time = str(
-        datetime.timedelta(seconds=int(time.time() - start_total_time))
-    )
-
-    logger.info(
-        "Average batch inference time: {:.5f}s".format(average_batch_inference_time)
-    )
-
-    times_file = "Times.txt"
-    logger.info("saving {}".format(times_file))
-
-    with open(os.path.join(results_subfolder, times_file), "w+") as outfile:
-        date = datetime.datetime.now()
-        outfile.write("Date: {} \n".format(date.strftime("%Y-%m-%d %H:%M:%S")))
-        outfile.write("Total evaluation run-time: {} \n".format(total_evalution_time))
-        outfile.write(
-            "Average batch inference time: {} \n".format(average_batch_inference_time)
-        )
-        outfile.write("Total inference time: {} \n".format(total_inference_time))
diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py
new file mode 100644
index 00000000..016ce0bc
--- /dev/null
+++ b/bob/ip/binseg/engine/predictor.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import time
+import datetime
+import numpy as np
+import torch
+from tqdm import tqdm
+
+import bob.io.base
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+def save_hdf5(predictions, names, output_folder):
+    """
+    Saves probability maps as HDF5 files, preserving input name structure
+
+
+    Parameters
+    ----------
+    predictions : :py:class:`torch.Tensor`
+        tensor with pixel-wise probabilities
+
+    names : list
+        list of file names
+
+    output_folder : str
+        output path
+
+    """
+
+    for j in range(predictions.size()[0]):
+
+        img = predictions.cpu().data[j].squeeze(0).numpy()
+        filename = "{}.hdf5".format(names[j].split(".")[0])
+        fullpath = os.path.join(output_folder, filename)
+        logger.info(f"saving {filename}")
+        fulldir = os.path.dirname(fullpath)
+        os.makedirs(fulldir, exist_ok=True)
+        bob.io.base.save(img, fullpath)
+
+
+def run(model, data_loader, device, output_folder):
+    """
+    Runs inference on input data, outputs HDF5 files with predictions
+
+    Parameters
+    ----------
+    model : :py:class:`torch.nn.Module`
+        neural network model (e.g. driu, hed, unet)
+
+    data_loader : :py:class:`torch.utils.data.DataLoader`
+
+    device : str
+        device to use, e.g. ``cpu`` or ``cuda:0``
+
+    output_folder : str
+        folder where to store output images (HDF5 files)
+
+    """
+
+    logger.info("Start prediction")
+    logger.info(f"Output folder: {output_folder}, Device: {device}")
+    os.makedirs(output_folder, exist_ok=True)
+
+    model.eval().to(device)
+    # Sigmoid for probabilities
+    sigmoid = torch.nn.Sigmoid()
+
+    # Setup timers
+    start_total_time = time.time()
+    times = []
+    len_samples = []
+
+    for samples in tqdm(
+        data_loader, desc="batches", leave=False, disable=None,
+    ):
+
+        names = samples[0]
+        images = samples[1].to(device)
+
+        with torch.no_grad():
+
+            start_time = time.perf_counter()
+            outputs = model(images)
+
+            # necessary check for HED architecture that uses several outputs
+            # for loss calculation instead of just the last concatfuse block
+            if isinstance(outputs, list):
+                outputs = outputs[-1]
+
+            probabilities = sigmoid(outputs)
+
+            batch_time = time.perf_counter() - start_time
+            times.append(batch_time)
+            len_samples.append(len(images))
+
+            save_hdf5(probabilities, names, output_folder)
+
+    logger.info("End prediction")
+
+    # report operational summary
+    total_time = datetime.timedelta(seconds=int(time.time() - start_total_time))
+    logger.info(f"Total time: {total_time}")
+
+    average_batch_time = np.mean(times)
+    logger.info(f"Average batch time: {average_batch_time:g}s")
+
+    # batch times weighted by the number of samples in each batch
+    average_image_time = np.sum(np.array(times) * len_samples) / float(
+        sum(len_samples)
+    )
+    logger.info(f"Average image time: {average_image_time:g}s")
diff --git a/bob/ip/binseg/script/binseg.py b/bob/ip/binseg/script/binseg.py
index a3a6a27f..8f63c83c 100644
--- a/bob/ip/binseg/script/binseg.py
+++ b/bob/ip/binseg/script/binseg.py
@@ -25,10 +25,9 @@ from torch.utils.data import DataLoader
 from bob.ip.binseg.utils.plot import plot_overview
 from bob.ip.binseg.utils.click import OptionEatAll
 from bob.ip.binseg.utils.rsttable import create_overview_grid
-from bob.ip.binseg.utils.plot import metricsviz, overlay, savetransformedtest
+from bob.ip.binseg.utils.plot import metricsviz, savetransformedtest
 from bob.ip.binseg.utils.transformfolder import transformfolder as transfld
 from bob.ip.binseg.utils.evaluate import do_eval
-from bob.ip.binseg.engine.predicter import do_predict
 
 logger = logging.getLogger(__name__)
 
@@ -125,54 +124,6 @@ def transformfolder(source_path, target_path, transforms, **kwargs):
     transfld(source_path, target_path, transforms)
 
 
-# Run inference and create predictions only (no ground truth available)
-@binseg.command(entry_point_group="bob.ip.binseg.config", cls=ConfigCommand)
-@click.option(
-    "--output-path", "-o", required=True, default="output", cls=ResourceOption
-)
-@click.option("--model", "-m", required=True, cls=ResourceOption)
-@click.option("--dataset", "-d", required=True, cls=ResourceOption)
-@click.option("--batch-size", "-b", required=True, default=2, cls=ResourceOption)
-@click.option(
-    "--device",
-    "-d",
-    help='A string indicating the device to use (e.g. "cpu" or "cuda:0"',
-    show_default=True,
-    required=True,
-    default="cpu",
-    cls=ResourceOption,
-)
-@click.option(
-    "--weight",
-    "-w",
-    help="Path or URL to pretrained model",
-    required=False,
-    default=None,
-    cls=ResourceOption,
-)
-@verbosity_option(cls=ResourceOption)
-def predict(model, output_path, device, batch_size, dataset, weight, **kwargs):
-    """ Run inference and evaluate the model performance """
-
-    # PyTorch dataloader
-    data_loader = DataLoader(
-        dataset=dataset,
-        batch_size=batch_size,
-        shuffle=False,
-        pin_memory=torch.cuda.is_available(),
-    )
-
-    # checkpointer, load last model in dir
-    checkpointer = DetectronCheckpointer(
-        model, save_dir=output_path, save_to_disk=False
-    )
-    checkpointer.load(weight)
-    do_predict(model, data_loader, device, output_path)
-
-    # Overlayed images
-    overlay(dataset=dataset, output_path=output_path)
-
-
 # Evaluate only. Runs evaluation on predicted probability maps (--prediction-folder)
 @binseg.command(entry_point_group="bob.ip.binseg.config", cls=ConfigCommand)
 @click.option(
diff --git a/bob/ip/binseg/script/evaluate.py b/bob/ip/binseg/script/evaluate.py
index bbbe8f4d..a036084c 100644
--- a/bob/ip/binseg/script/evaluate.py
+++ b/bob/ip/binseg/script/evaluate.py
@@ -1,9 +1,6 @@
 #!/usr/bin/env python
 # coding=utf-8
 
-import os
-import pkg_resources
-
 import click
 from click_plugins import with_plugins
 
@@ -36,6 +33,14 @@ logger = logging.getLogger(__name__)
 """,
 )
+@click.option(
+    "--output-path",
+    "-o",
+    help="Path where to store the evaluation results (created if it does not exist)",
+    required=True,
+    default="results",
+    cls=ResourceOption,
+)
 @click.option(
     "--model",
     "-m",
@@ -76,7 +81,7 @@ logger = logging.getLogger(__name__)
     cls=ResourceOption,
 )
 @verbosity_option(cls=ResourceOption)
-def evaluate(model, output_path, device, batch_size, dataset, weight, **kwargs):
+def evaluate(output_path, model, dataset, batch_size, device, weight, **kwargs):
     """Evaluates an FCN on a binary segmentation task.
     """
diff --git a/bob/ip/binseg/script/predict.py b/bob/ip/binseg/script/predict.py
new file mode 100644
index 00000000..8fe08c40
--- /dev/null
+++ b/bob/ip/binseg/script/predict.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import click
+
+import torch
+from torch.utils.data import DataLoader
+
+from bob.extension.scripts.click_helper import (
+    verbosity_option,
+    ConfigCommand,
+    ResourceOption,
+)
+
+from ..engine.predictor import run
+from ..utils.checkpointer import DetectronCheckpointer
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+@click.command(
+    entry_point_group="bob.ip.binseg.config",
+    cls=ConfigCommand,
+    epilog="""Examples:
+
+\b
+    1. Runs prediction on an existing dataset configuration:
+
+       $ bob binseg predict -vv m2unet drive-test --weight=path/to/model_final.pth --output-path=path/to/predictions
+\b
+    2. To run prediction on a folder with your own images, you must first
+       specify resizing, cropping, etc., so that the images can be correctly
+       input to the model.  Failing to do so will likely result in poor
+       performance.  To figure out such specifications, you must consult the
+       dataset configuration used for **training** the provided model.  Once
+       you have figured this out, do the following:
+
+       $ bob binseg config copy folder-dataset-example myfolder.py
+       # modify "myfolder.py" to include the base path and required transforms
+       $ bob binseg predict -vv m2unet myfolder.py --weight=path/to/model_final.pth --output-path=path/to/predictions
+""",
+)
+@click.option(
+    "--output-path",
+    "-o",
+    help="Path where to store the generated predictions (created if it does not exist)",
+    required=True,
+    default="results",
+    cls=ResourceOption,
+)
+@click.option(
+    "--model",
+    "-m",
+    help="A torch.nn.Module instance implementing the network to be evaluated",
+    required=True,
+    cls=ResourceOption,
+)
+@click.option(
+    "--dataset",
+    "-d",
+    help="A torch.utils.data.dataset.Dataset instance implementing a dataset to be used for evaluating the model, possibly including all pre-processing pipelines required",
+    required=True,
+    cls=ResourceOption,
+)
+@click.option(
+    "--batch-size",
+    "-b",
+    help="Number of samples in every batch (this parameter affects memory requirements for the network)",
+    required=True,
+    show_default=True,
+    default=1,
+    cls=ResourceOption,
+)
+@click.option(
+    "--device",
+    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
+    show_default=True,
+    required=True,
+    default="cpu",
+    cls=ResourceOption,
+)
+@click.option(
+    "--weight",
+    "-w",
+    help="Path or URL to pretrained model file (.pth extension)",
+    required=True,
+    cls=ResourceOption,
+)
+@verbosity_option(cls=ResourceOption)
+def predict(output_path, model, dataset, batch_size, device, weight, **kwargs):
+    """Predicts vessel map (probabilities) on input images"""
+
+    # PyTorch dataloader
+    data_loader = DataLoader(
+        dataset=dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        pin_memory=torch.cuda.is_available(),
+    )
+
+    # checkpointer, loads pre-fit model
+    checkpointer = DetectronCheckpointer(model, save_dir=output_path,
+        save_to_disk=False)
+    checkpointer.load(weight)
+
+    run(model, data_loader, device, output_path)
diff --git a/conda/meta.yaml b/conda/meta.yaml
index 753656b2..bc094a41 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -40,6 +40,7 @@ requirements:
     - tqdm
     - tabulate
    - bob.core
+    - bob.io.base
 
 test:
   imports:
diff --git a/doc/api.rst b/doc/api.rst
index b0e5fb83..3643295b 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -17,8 +17,8 @@ Data Manipulation
 
    bob.ip.binseg.data
    bob.ip.binseg.data.binsegdataset
-   bob.ip.binseg.data.imagefolder
-   bob.ip.binseg.data.imagefolderinference
+   bob.ip.binseg.data.csvdataset
+   bob.ip.binseg.data.folderdataset
    bob.ip.binseg.data.transforms
 
 
@@ -31,7 +31,7 @@ Engines
    bob.ip.binseg.engine
    bob.ip.binseg.engine.adabound
    bob.ip.binseg.engine.inferencer
-   bob.ip.binseg.engine.predicter
+   bob.ip.binseg.engine.predictor
    bob.ip.binseg.engine.ssltrainer
    bob.ip.binseg.engine.trainer
 
@@ -125,6 +125,7 @@ Datasets
    bob.ip.binseg.configs.datasets.chasedb1544
    bob.ip.binseg.configs.datasets.chasedb1608
    bob.ip.binseg.configs.datasets.chasedb1test
+   bob.ip.binseg.configs.datasets.csv
    bob.ip.binseg.configs.datasets.drionsdb
    bob.ip.binseg.configs.datasets.drionsdbtest
    bob.ip.binseg.configs.datasets.dristhigs1cup
@@ -147,6 +148,7 @@ Datasets
    bob.ip.binseg.configs.datasets.drivestareiostarhrf960
    bob.ip.binseg.configs.datasets.drivestareiostarhrf960sslchase
    bob.ip.binseg.configs.datasets.drivetest
+   bob.ip.binseg.configs.datasets.folder
    bob.ip.binseg.configs.datasets.hrf
    bob.ip.binseg.configs.datasets.hrf1024
    bob.ip.binseg.configs.datasets.hrf1168
@@ -156,9 +158,6 @@ Datasets
    bob.ip.binseg.configs.datasets.hrf608
    bob.ip.binseg.configs.datasets.hrf960
    bob.ip.binseg.configs.datasets.hrftest
-   bob.ip.binseg.configs.datasets.imagefolder
-   bob.ip.binseg.configs.datasets.imagefolderinference
-   bob.ip.binseg.configs.datasets.imagefoldertest
    bob.ip.binseg.configs.datasets.iostarod
    bob.ip.binseg.configs.datasets.iostarodtest
    bob.ip.binseg.configs.datasets.iostarvessel
@@ -185,7 +184,6 @@ Datasets
    bob.ip.binseg.configs.datasets.starechasedb1iostarhrf544ssldrive
    bob.ip.binseg.configs.datasets.staretest
 
-
 Test Units
 ----------
 
diff --git a/doc/cli.rst b/doc/cli.rst
index d65aaf79..fc377892 100644
--- a/doc/cli.rst
+++ b/doc/cli.rst
@@ -71,12 +71,26 @@ evaluation tests or for inference.
 
 .. command-output:: bob binseg train --help
 
+.. _bob.ip.binseg.cli.predict:
+
+FCN Inference
+-------------
+
+Inference takes as input a PyTorch_ model and generates output probability
+maps as HDF5 files.  A probability map has the same size as the input image;
+each pixel holds a floating-point number between 0.0 (vessel less probable)
+and 1.0 (vessel more probable).
+
+.. command-output:: bob binseg predict --help
+
+
 .. _bob.ip.binseg.cli.evaluate:
 
 FCN Performance Evaluation
 --------------------------
 
-Evaluation takes as input a PyTorch_ model and generates analysis information.
+Evaluation takes inference results and compares them to ground-truth,
+generating a series of analysis figures which are useful to understand model
+performance.
 
 .. command-output:: bob binseg evaluate --help
diff --git a/doc/conf.py b/doc/conf.py
index 2017d7b2..c822aa64 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -256,26 +256,3 @@ intersphinx_mapping['click'] = ('https://click.palletsprojects.com/en/%s/' % (cl
 
 # Add our private index (for extras and fixes)
 intersphinx_mapping['extras'] = ('', 'extras.inv')
-
-# We want to remove all private (i.e. _. or __.__) members
-# that are not in the list of accepted functions
-accepted_private_functions = ['__array__']
-
-
-def member_function_test(app, what, name, obj, skip, options):
-    # test if we have a private function
-    if len(name) > 1 and name[0] == '_':
-        # test if this private function should be allowed
-        if name not in accepted_private_functions:
-            # omit privat functions that are not in the list of accepted private functions
-            return skip
-    else:
-        # test if the method is documented
-        if not hasattr(obj, '__doc__') or not obj.__doc__:
-            return skip
-    return False
-
-
-def setup(app):
-    app.connect('autodoc-skip-member', member_function_test)
-
diff --git a/doc/datasets.rst b/doc/datasets.rst
index 54b45744..92f74bf9 100644
--- a/doc/datasets.rst
+++ b/doc/datasets.rst
@@ -148,43 +148,4 @@ to generate iterators for training and testing.
       - 400
 
 
-Folder-based Dataset
---------------------
-
-For quick experimentation, we also provide a PyTorch_ class that works with the
-following dataset folder structure for images and ground-truth (gt):
-
-.. code-block:: text
-
-    root
-       |- images
-       |- gt
-
-
-The file names should have the same stem.  Currently, all image formats that can
-be read via PIL are supported.  Additionally, we also support HDF5 binary
-files.
-
-For training, a new dataset configuration needs to be created. You can copy the
-template :py:mod:`bob.ip.binseg.configs.datasets.imagefolder` and amend it
-accordingly, e.g. to point to the the full path of the dataset and if necessary
-any preprocessing steps such as resizing, cropping, padding etc.
-
-Training can then be started with, e.g.:
-
-.. code-block:: sh
-
-    bob binseg train M2UNet /path/to/myimagefolderconfig.py -b 4 -d cuda -o /my/output/path -vv
-
-Similary for testing, a test dataset config needs to be created. You can copy
-the template :py:mod:`bob.ip.binseg.configs.datasets.imagefoldertest` and amend
-it accordingly.
-
-Testing can then be started with, e.g.:
-
-.. code-block:: bash
-
-    bob binseg test M2UNet /path/to/myimagefoldertestconfig.py -b 2 -d cuda -o /my/output/path -vv
-
-
 .. include:: links.rst
diff --git a/doc/evaluation.rst b/doc/evaluation.rst
index 8bc91d0a..3b0f758d 100644
--- a/doc/evaluation.rst
+++ b/doc/evaluation.rst
@@ -2,9 +2,59 @@
 
 .. _bob.ip.binseg.eval:
 
-============
- Evaluation
-============
+==========================
+ Inference and Evaluation
+==========================
+
+
+Inference
+---------
+
+You may use one of your trained models (or :ref:`one of ours
+<bob.ip.binseg.models>`) to run inference on existing datasets or your own
+dataset.
+
+
+Inference on existing datasets
+==============================
+
+Use the sub-command :ref:`predict <bob.ip.binseg.cli.predict>` to run
+inference on an existing dataset:
+
+.. code-block:: sh
+
+   $ bob binseg predict -vv <model> -w <path/to/model.pth> <dataset>
+
+
+Replace ``<model>`` and ``<dataset>`` by the appropriate :ref:`configuration
+files <bob.ip.binseg.configs>`.  Replace ``<path/to/model.pth>`` with a path
+leading to the pre-trained model, or a URL pointing to one (e.g. :ref:`one of
+ours <bob.ip.binseg.models>`).
+
+
+Inference on a custom dataset
+=============================
+
+If you would like to test your own data against one of the pre-trained models,
+you need to instantiate one of:
+
+* :py:mod:`A CSV-based configuration <bob.ip.binseg.configs.datasets.csv>`
+* :py:mod:`A folder-based configuration <bob.ip.binseg.configs.datasets.folder>`
+
+Read the appropriate module documentation for details.
+
+
+.. code-block:: bash
+
+   $ bob binseg config copy folder-dataset-example mydataset.py
+   # or
+   $ bob binseg config copy csv-dataset-example mydataset.py
+   # edit mydataset.py to your liking
+   $ bob binseg predict -vv <model> -w <path/to/model.pth> ./mydataset.py
+
+
+Evaluation
+----------
 
 To evaluate trained models use our CLI interface. ``bob binseg evaluate``
 followed by the model and the dataset configuration, and the path to the
@@ -49,19 +99,6 @@ The inference run generates the following output files:
     └── Times.txt  # inference times
 
 
-Inference Only Mode
-====================
-
-If you wish to run inference only on a folder containing images, use the
-``predict`` function in combination with a
-:py:mod:`bob.ip.binseg.configs.datasets.imagefolderinference` config. E.g.:
-
-.. code-block:: bash
-
-    bob binseg predict M2UNet /path/to/myinferencedatasetconfig.py -b 1 -d cpu -o /my/output/path -w /path/to/pretrained/weight/model_final.pth -vv
-
-
 To run evaluation of pretrained models pass url as ``-w`` argument. E.g.:
 
 .. code-block:: bash
diff --git a/doc/training.rst b/doc/training.rst
index db9daed6..7096bb8a 100644
--- a/doc/training.rst
+++ b/doc/training.rst
@@ -22,6 +22,7 @@ command-line options.  Use ``bob binseg train --help`` for more information.
    Depending on the available GPU memory you might have to adjust your batch
    size (``--batch``).
 
+
 Baseline Benchmarks
 ===================
 
@@ -147,3 +148,20 @@ card, for semi-supervised learning of COVD- systems.  Use it like this:
      - 2
      - 2
      - 2
+
+
+Using your own dataset
+======================
+
+To use your own dataset, we recommend you read our instructions at
+:py:mod:`bob.ip.binseg.configs.datasets.csv`, and set up a CSV file describing
+input data and ground-truth (segmentation maps).  Then, prepare a
+configuration file by copying our configuration example and editing it to
+apply the required transforms to your input data.  Once you are happy with the
+result, use it in place of one of our datasets:
+
+.. code-block:: sh
+
+   $ bob binseg config copy csv-dataset-example mydataset.py
+   # edit mydataset.py following the instructions
+   $ bob binseg train ... mydataset.py ...
diff --git a/requirements.txt b/requirements.txt
index 82b7843a..c285bd22 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ matplotlib
 tqdm
 tabulate
 bob.core
+bob.io.base
diff --git a/setup.py b/setup.py
index 3fe6be37..452c77c0 100644
--- a/setup.py
+++ b/setup.py
@@ -34,10 +34,10 @@ setup(
             "compare = bob.bin.binseg.script.binseg:compare",
             "evalpred = bob.ip.binseg.script.binseg:evalpred",
             "gridtable = bob.ip.binseg.script.binseg:testcheckpoints",
-            "predict = bob.ip.binseg.script.binseg:predict",
             "visualize = bob.ip.binseg.script.binseg:visualize",
             "config = bob.ip.binseg.script.config:config",
             "train = bob.ip.binseg.script.train:train",
+            "predict = bob.ip.binseg.script.predict:predict",
             "evaluate = bob.ip.binseg.script.evaluate:evaluate",
         ],
         # bob train configurations
@@ -56,7 +56,8 @@ setup(
             "resunet = bob.ip.binseg.configs.models.resunet",
 
             # datasets
-            "imagefolder = bob.ip.binseg.configs.datasets.imagefolder",
+            "csv-dataset-example = bob.ip.binseg.configs.datasets.csv",
+            "folder-dataset-example = bob.ip.binseg.configs.datasets.folder",
 
             # drive dataset (numbers represent target resolution)
             "drive = bob.ip.binseg.configs.datasets.drive",
-- 
GitLab