Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • medai/software/mednet
1 result
Show changes
Commits on Source (28)
Showing
with 635 additions and 68 deletions
......@@ -10,6 +10,7 @@ include:
variables:
GIT_SUBMODULE_STRATEGY: normal
GIT_SUBMODULE_DEPTH: 1
XDG_CONFIG_HOME: $CI_PROJECT_DIR/tests/data
documentation:
before_script:
......
......@@ -21,5 +21,6 @@ Files:
tests/data/*.csv
tests/data/*.json
tests/data/*.png
tests/*.toml
Copyright: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
License: GPL-3.0-or-later
......@@ -63,7 +63,7 @@ def _overlay_saliency_map(
Parameters
----------
image
The input image that will be overlayed with the saliency map.
The input image that will be overlaid with the saliency map.
saliencies
The saliency map that will be overlaid on the (raw) image.
colormap
......@@ -118,7 +118,7 @@ def _overlay_bounding_box(
Parameters
----------
image
The input image that will be overlayed with the saliency map.
The input image that will be overlaid with the saliency map.
bbox
The bounding box to draw on the input image.
color
......@@ -149,13 +149,13 @@ def _process_sample(
saliencies: numpy.typing.NDArray[numpy.double],
ground_truth: BoundingBoxes,
) -> PIL.Image.Image:
"""Generate an overlayed representation of the original sample and saliency
"""Generate an overlaid representation of the original sample and saliency
maps.
Parameters
----------
raw_data
The raw data representing the input sample that will be overlayed with
The raw data representing the input sample that will be overlaid with
saliency maps and annotations.
saliencies
The saliency map recovered from the model, that will be imprinted on
......@@ -166,7 +166,7 @@ def _process_sample(
Returns
-------
PIL.Image.Image
An image with the original raw data overlayed with the different
An image with the original raw data overlaid with the different
elements as selected by the user.
"""
......
......@@ -75,7 +75,11 @@ class Pasa(pl.LightningModule):
self.model_transforms = [
Grayscale(),
SquareCenterPad(),
torchvision.transforms.Resize(512, antialias=True),
torchvision.transforms.Resize(
512,
antialias=True,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
),
]
self._train_loss = train_loss
......
......@@ -67,7 +67,6 @@ def experiment(
save_sh_command(output_folder / "command.sh")
# training
logger.info("Started training")
from .train import train
......@@ -139,3 +138,35 @@ def experiment(
)
logger.info("Ended evaluating")
logger.info("Started generating saliencies")
from .saliency.generate import generate
saliencies_gen_folder = output_folder / "gradcam" / "saliencies"
ctx.invoke(
generate,
model=model,
datamodule=datamodule,
weight=train_output_folder,
output_folder=saliencies_gen_folder,
)
logger.info("Ended generating saliencies")
logger.info("Started viewing saliencies")
from .saliency.view import view
saliencies_view_folder = output_folder / "gradcam" / "visualizations"
ctx.invoke(
view,
model=model,
datamodule=datamodule,
input_folder=saliencies_gen_folder,
output_folder=saliencies_view_folder,
)
logger.info("Ended viewing saliencies")
......@@ -23,10 +23,19 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
.. code:: sh
mednet saliency interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent-folder/saliencies/ --output-json=path/to/interpretability-scores.json
mednet saliency interpretability -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent-folder/saliencies/ --output-json=path/to/interpretability-scores.json
""",
)
@click.option(
"--model",
"-m",
help="""A lightning module instance implementing the network architecture
(not the weights, necessarily) to be used for inference. Currently, only
supports pasa and densenet models.""",
required=True,
cls=ResourceOption,
)
@click.option(
"--datamodule",
"-d",
......@@ -78,6 +87,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def interpretability(
model,
datamodule,
input_folder,
target_label,
......@@ -114,7 +124,7 @@ def interpretability(
from ...engine.saliency.interpretability import run
datamodule.model_transforms = []
datamodule.model_transforms = model.transforms
datamodule.prepare_data()
datamodule.setup(stage="predict")
......
......@@ -2,16 +2,14 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import os
import pathlib
import tempfile
import typing
import zipfile
import numpy
import pytest
import tomli_w
import torch
from mednet.data.split import JSONDatabaseSplit
from mednet.data.typing import DatabaseSplit
......@@ -97,55 +95,6 @@ def temporary_basedir(tmp_path_factory):
return tmp_path_factory.mktemp("test-cli")
def pytest_sessionstart(session: pytest.Session) -> None:
"""Preset the session start to ensure the Montgomery dataset is always available.
Parameters
----------
session
The session to use.
"""
from mednet.utils.rc import load_rc
rc = load_rc()
database_dir = rc.get("datadir.montgomery")
if database_dir is not None:
# if the user downloaded it, use that copy
return
# else, we must extract the LFS component (we are likely on the CI)
archive = (
pathlib.Path(__file__).parents[0] / "data" / "lfs" / "test-database.zip"
)
assert archive.exists(), (
f"Neither datadir.montgomery is set on the global configuration, "
f"(typically ~/.config/mednet.toml), or it is possible to detect "
f"the presence of {archive}' (did you git submodule init --update "
f"this submodule?)"
)
montgomery_tempdir = tempfile.TemporaryDirectory()
rc.setdefault("datadir.montgomery", montgomery_tempdir.name)
with zipfile.ZipFile(archive) as zf:
zf.extractall(montgomery_tempdir.name)
config_filename = "mednet.toml"
with open(
os.path.join(montgomery_tempdir.name, config_filename), "wb"
) as f:
tomli_w.dump(rc.data, f)
f.flush()
os.environ["XDG_CONFIG_HOME"] = montgomery_tempdir.name
# stash the newly created temporary directory so we can erase it when the
key = pytest.StashKey[tempfile.TemporaryDirectory]()
session.stash[key] = montgomery_tempdir
class DatabaseCheckers:
"""Helpers for database tests."""
......@@ -156,13 +105,14 @@ class DatabaseCheckers:
prefixes: typing.Sequence[str],
possible_labels: typing.Sequence[int],
):
"""Run a simple consistence check on the data split.
"""Run a simple consistency check on the data split.
Parameters
----------
split
An instance of DatabaseSplit.
lengths
A dictionary that contains keys matching those of the split (this will
be checked). The values of the dictionary should correspond to the
sizes of each of the datasets in the split.
......@@ -197,12 +147,13 @@ class DatabaseCheckers:
color_planes: int,
prefixes: typing.Sequence[str],
possible_labels: typing.Sequence[int],
expected_num_labels: int,
expected_image_shape: typing.Optional[tuple[int, ...]] = None,
):
"""Check the consistence of an individual (loaded) batch.
"""Check the consistency of an individual (loaded) batch.
Parameters
----------
batch
The loaded batch to be checked.
batch_size
......@@ -214,14 +165,22 @@ class DatabaseCheckers:
prefixes.
possible_labels
These are the list of possible labels contained in any split.
expected_num_labels
The expected number of labels each sample should have.
expected_image_shape
The expected shape of the image (num_channels, width, height).
"""
assert len(batch) == 2 # data, metadata
assert isinstance(batch[0], torch.Tensor)
assert batch[0].shape[0] == batch_size # mini-batch size
assert batch[0].shape[1] == color_planes # grayscale images
assert batch[0].shape[2] == batch[0].shape[3] # image is square
assert batch[0].shape[1] == color_planes
if expected_image_shape:
assert all(
[data.shape == expected_image_shape for data in batch[0]]
)
assert isinstance(batch[1], dict) # metadata
assert len(batch[1]) == 2 # label and name
......@@ -229,6 +188,9 @@ class DatabaseCheckers:
assert "label" in batch[1]
assert all([k in possible_labels for k in batch[1]["label"]])
if expected_num_labels:
assert len(batch[1]["label"]) == expected_num_labels
assert "name" in batch[1]
assert all(
[any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]]
......@@ -239,6 +201,56 @@ class DatabaseCheckers:
# to_pil_image(batch[0][0]).show()
# __import__("pdb").set_trace()
@staticmethod
def check_image_quality(
datamodule,
reference_histogram_file,
compare_type="equal",
pearson_coeff_threshold=0.005,
):
ref_histogram_splits = JSONDatabaseSplit(reference_histogram_file)
for split_name in ref_histogram_splits:
raw_samples = datamodule.splits[split_name][0][0]
# It is not possible to get a sample from a Dataset by name/path, only by index.
# This creates a dict of sample name to dataset index.
raw_samples_indices = {}
for idx, rs in enumerate(raw_samples):
raw_samples_indices[rs[0]] = idx
for ref_hist_path, ref_hist_data in ref_histogram_splits[
split_name
]:
# Get index in the dataset that will return the data corresponding to the specified sample name
dataset_sample_index = raw_samples_indices[ref_hist_path]
image_tensor = datamodule._datasets[split_name][
dataset_sample_index
][0]
histogram = []
for color_channel in image_tensor:
color_channel = numpy.multiply(
color_channel.numpy(), 255
).astype(int)
histogram.extend(
numpy.histogram(
color_channel, bins=256, range=(0, 256)
)[0].tolist()
)
if compare_type == "statistical":
# Compute pearson coefficients between histogram and reference
# and check the similarity within a certain threshold
pearson_coeffs = numpy.corrcoef(histogram, ref_hist_data)
assert (
1 - pearson_coeff_threshold <= pearson_coeffs[0][1] <= 1
)
else:
assert histogram == ref_hist_data
@pytest.fixture
def database_checkers():
......
This diff is collapsed.
{
"train": [
["43/216840111366964012558082906712009341094009402_00-084-033.png", [10737, 236, 243, 283, 266, 289, 465, 564, 527, 609, 596, 713, 675, 602, 698, 668, 654, 581, 619, 770, 717, 868, 798, 877, 914, 752, 975, 1168, 1238, 1515, 1247, 1446, 1658, 1754, 1853, 2385, 2756, 2908, 3404, 4017, 4205, 5651, 6197, 7375, 8072, 8391, 10084, 11275, 11945, 11665, 14699, 15124, 14660, 16467, 15130, 17940, 18101, 16597, 19790, 19148, 18556, 19904, 18454, 21009, 20369, 21128, 20932, 18738, 21558, 18474, 21116, 19820, 17114, 19056, 16118, 17543, 17911, 15672, 17568, 17190, 16721, 14801, 16232, 16359, 14084, 15901, 16576, 14584, 16598, 15642, 17028, 17791, 15826, 17573, 16294, 17610, 18394, 18038, 18514, 17829, 19453, 17992, 20473, 20042, 18745, 19884, 17932, 20704, 19903, 18587, 19957, 19659, 20286, 17691, 20508, 20343, 18544, 21237, 20576, 18031, 20049, 16705, 18796, 17822, 15893, 18329, 15456, 19854, 15521, 17042, 16972, 14350, 15526, 14155, 15253, 15034, 12899, 14015, 13790, 11762, 13002, 11305, 12491, 13047, 12793, 11799, 13769, 13376, 12398, 13218, 12974, 11865, 12523, 11408, 12257, 12249, 11243, 11988, 12095, 11481, 10329, 11479, 11115, 9501, 10360, 8846, 9883, 9548, 8501, 9342, 8987, 8140, 9050, 7703, 8985, 8216, 8057, 7508, 7951, 8016, 6715, 7415, 7453, 6199, 6911, 5830, 6213, 6309, 5453, 5810, 5954, 5670, 5127, 5351, 5424, 4955, 5369, 5059, 5279, 5411, 5031, 5315, 5650, 4873, 5498, 5270, 5640, 5486, 5730, 5572, 4992, 5342, 4942, 5538, 5395, 4951, 5190, 4634, 5486, 5224, 4576, 5273, 5068, 5234, 4510, 4987, 5185, 4438, 4973, 4840, 4249, 4828, 4163, 4632, 4865, 4272, 4715, 4183, 4570, 4888, 4572, 4816, 4089, 4526, 4331, 4587, 4753, 4164, 4591, 4790, 4206, 4505, 4232, 4626, 112561, 0]],
["8/333859634267203948016118549007441698534_uwr29s.png", [197384, 3762, 1369, 1334, 1395, 1595, 1926, 2249, 2532, 2644, 2897, 2468, 2378, 2035, 2259, 2175, 2494, 2334, 2562, 2585, 2871, 2968, 3296, 3247, 3240, 3256, 3410, 3361, 3281, 3514, 3651, 3588, 3806, 4145, 4640, 4984, 5552, 6320, 6891, 7834, 8292, 9131, 10701, 10698, 12503, 12915, 15505, 15069, 17255, 17475, 18302, 19504, 19210, 21324, 21976, 21845, 22810, 23324, 25640, 23114, 25154, 23772, 27263, 24559, 26156, 24730, 26205, 24818, 24750, 24299, 26058, 23064, 24756, 23025, 24695, 25177, 23767, 23694, 25906, 24191, 26519, 24552, 24639, 25223, 25472, 24044, 26070, 26329, 26674, 25620, 27633, 28035, 26682, 26914, 27654, 30294, 30761, 28871, 28868, 31199, 27313, 29155, 29075, 28828, 28954, 28719, 26139, 27992, 27505, 25392, 27397, 27167, 26840, 25118, 24682, 28713, 25122, 24943, 25295, 27756, 25600, 26463, 26850, 27024, 28139, 28463, 29407, 32209, 31149, 32024, 33300, 34006, 35035, 33281, 36901, 38055, 38863, 40109, 38213, 42411, 43246, 40821, 44952, 49059, 42409, 42841, 47471, 47406, 43626, 44150, 44470, 49013, 49121, 45831, 46707, 50751, 47165, 47459, 47735, 55790, 47982, 48246, 48090, 52350, 47494, 47202, 46504, 45508, 44375, 43414, 42542, 42194, 41426, 41280, 40815, 44417, 40793, 41086, 41150, 41719, 41787, 42668, 43055, 39754, 47335, 40783, 44309, 44699, 44983, 45223, 41487, 48810, 45196, 45483, 41849, 45760, 45817, 45715, 42312, 46541, 46199, 46831, 44218, 48151, 47783, 47317, 43379, 46502, 49218, 39854, 38440, 39564, 36323, 32955, 27414, 24583, 25731, 18628, 16098, 14629, 12159, 8972, 8024, 6479, 5241, 3799, 2835, 2591, 1959, 1435, 1156, 822, 683, 459, 389, 304, 191, 126, 105, 71, 33, 21, 12, 9, 6, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],
["1/308671648074000163114167886553150666431_2pac7b.png", [415725, 13365, 13405, 13074, 13057, 12859, 12510, 11188, 9033, 7554, 7159, 7240, 7207, 7190, 7471, 7818, 8271, 7439, 7543, 7349, 7379, 6975, 6660, 6544, 6835, 7064, 7071, 7279, 7295, 7467, 7740, 7915, 8942, 8934, 8867, 9090, 9398, 9390, 9700, 9960, 10418, 10882, 10917, 11266, 11461, 11361, 11656, 11952, 13485, 13174, 14111, 14915, 15179, 16118, 17246, 18101, 18917, 20147, 21127, 22264, 23308, 24464, 25080, 26365, 29050, 28459, 29950, 30309, 31724, 32143, 33019, 33685, 34455, 35596, 35374, 36142, 36872, 37187, 37036, 37781, 40134, 37666, 38461, 38311, 35561, 38704, 38455, 38893, 41009, 38520, 38516, 38387, 37630, 36877, 36713, 35997, 35313, 34617, 34526, 33365, 33124, 33245, 32861, 33154, 35506, 32917, 33201, 33059, 32954, 33228, 33461, 33694, 33429, 33863, 33884, 33993, 33256, 34315, 33293, 34024, 36341, 33891, 34925, 34970, 35946, 35969, 36549, 36860, 37578, 38166, 38764, 39110, 39932, 40924, 41529, 45547, 42360, 43596, 43185, 43975, 44099, 44909, 45731, 46346, 47064, 47151, 47223, 47559, 47686, 48352, 48266, 51367, 50013, 50471, 50838, 52617, 53091, 55417, 55411, 57044, 59114, 60543, 61597, 62692, 62779, 65087, 65925, 69340, 65062, 61217, 63925, 63177, 61967, 60985, 57523, 61101, 55443, 54322, 52669, 52419, 51823, 50422, 51314, 50217, 50877, 49093, 49246, 49808, 48752, 48014, 47755, 49096, 46330, 45417, 44649, 44399, 43232, 43687, 43457, 42807, 44149, 43206, 44650, 45224, 45816, 47155, 47182, 51341, 48003, 47250, 46581, 46459, 44426, 42731, 40682, 38041, 35549, 33123, 30073, 27003, 23769, 21747, 18976, 17217, 13738, 11694, 9882, 8089, 6748, 5188, 4225, 3191, 2474, 1835, 1281, 953, 668, 469, 287, 214, 109, 68, 34, 12, 7, 3, 3, 1, 0, 0, 0, 0, 0, 0, 731, 0]],
["18/216840111366964012810946289282010224084149790_02-122-189.png", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 230289, 18887, 17066, 41210, 30652, 32960, 20785, 18608, 15907, 16909, 15839, 10107, 21932, 15201, 11076, 13022, 11474, 12948, 11910, 5668, 13725, 12070, 6150, 13745, 8910, 9925, 10784, 9332, 6540, 3629, 18143, 3408, 15974, 3905, 10922, 7430, 13417, 4258, 8085, 11817, 13965, 4495, 8446, 12800, 9509, 5656, 9805, 4774, 15638, 11945, 7117, 19736, 7231, 16374, 30465, 0, 20772, 34906, 13476, 30847, 33750, 15981, 33631, 37916, 20491, 47676, 22205, 21555, 43515, 46025, 50788, 30304, 51577, 24318, 49358, 52913, 28530, 60081, 54803, 27246, 57173, 61877, 36206, 31595, 92530, 30931, 65762, 75760, 34673, 0, 67676, 66797, 34978, 72941, 79296, 36031, 68811, 34375, 72694, 38898, 82268, 73556, 36607, 75563, 82037, 90861, 42007, 41394, 82027, 88214, 47650, 102765, 93194, 46871, 47943, 104943, 58135, 127917, 111671, 111326, 58879, 127323, 201914, 62101, 123191, 130999, 69929, 151832, 129980, 62629, 192061, 68527, 213370, 63911, 62974, 62389, 202724, 81890, 204539, 128085, 133249, 152359, 68340, 129314, 181233, 60688, 194370, 59177, 163140, 107507, 119867, 108043, 99424, 47698, 146249, 104954, 90618, 84496, 128586, 92688, 80290, 113857, 78824, 85740, 110984, 72060, 74244, 116923, 69194, 68090, 72569, 78101, 69700, 140514, 38547, 119647, 106158, 112682, 46489, 110995, 99749, 68850, 107773, 87431, 55629, 89837, 27047, 71928, 46976, 76268, 43984, 81128, 67245, 39680, 76440, 20897, 65079, 57063, 63547, 48647, 61655, 41493, 72708, 86208, 19796, 64448, 89527, 58164, 42593, 68758, 57594, 61684, 68091, 56611, 219333, 191372, 222988, 135915, 315853, 162814, 340543, 339464, 136999, 17744, 0, 0, 0, 0, 0, 0, 0, 0]],
["16/216840111366964012809176623042010216143036077_02-118-034.png", [466312, 1995, 1878, 2065, 2468, 2506, 2484, 2546, 2604, 2547, 2594, 2808, 2856, 3090, 3184, 3140, 3651, 3155, 3195, 3363, 3453, 3427, 3567, 3531, 3824, 3655, 3663, 3708, 3790, 3729, 3575, 3816, 4133, 3808, 4060, 4181, 4516, 4369, 4395, 4506, 4603, 4566, 4531, 4687, 4400, 4673, 4985, 5361, 6445, 5619, 5790, 6135, 6685, 7300, 7776, 9107, 9119, 9667, 10070, 10862, 11680, 12225, 13394, 13342, 15352, 14940, 16000, 17809, 17741, 18297, 18472, 18212, 18262, 17938, 18481, 17584, 17567, 17602, 17761, 17813, 20835, 18751, 18860, 18374, 17256, 16899, 17129, 16289, 16438, 16915, 17307, 17046, 17496, 18361, 17883, 18053, 20222, 18232, 19346, 18085, 18509, 18410, 18545, 18274, 18264, 19304, 18569, 18671, 18631, 18631, 18430, 18688, 18499, 19680, 21060, 19014, 19513, 20314, 19338, 18864, 18805, 18620, 18620, 18261, 18100, 19278, 18481, 18879, 18737, 18897, 19071, 18767, 19846, 19118, 21014, 19159, 19157, 19990, 19691, 20094, 20180, 20028, 20723, 20917, 21887, 20965, 21229, 21178, 21026, 21344, 21162, 21857, 21169, 20778, 22339, 20196, 20734, 19580, 18974, 18699, 18624, 17921, 17795, 18351, 17489, 17624, 17461, 17578, 17440, 17525, 17002, 17530, 16474, 15945, 17505, 16346, 15685, 15737, 15563, 15580, 15430, 15215, 15179, 15626, 14761, 14600, 14604, 14489, 14704, 14595, 15343, 16139, 15094, 15168, 15103, 15938, 15224, 14972, 14764, 14389, 13997, 13878, 14197, 13761, 13211, 13297, 13195, 15087, 13303, 13380, 13483, 13542, 14013, 13848, 14683, 14172, 14280, 14548, 14699, 14705, 14711, 15554, 14876, 16410, 14830, 14517, 15229, 14578, 14454, 14248, 14071, 13584, 13343, 13169, 12089, 11446, 10719, 9796, 9571, 9791, 8004, 7492, 7002, 6533, 6273, 6359, 5839, 5659, 5511, 5290, 5158, 5037, 4873, 4796, 131862, 0]]
]
}
This diff is collapsed.
Subproject commit 05344d20182ad4169fc5b9c38052d629aded30ed
Subproject commit fe158e393b7904c6b381b0d9862895bf05c6cf7a