diff --git a/src/mednet/libs/common/scripts/database.py b/src/mednet/libs/common/scripts/database.py index 07c6f1f8d63fbfafa082f7a9cba2e834b2158f99..7bb561dc633f071306342a747167164bfdb19cd4 100644 --- a/src/mednet/libs/common/scripts/database.py +++ b/src/mednet/libs/common/scripts/database.py @@ -75,10 +75,16 @@ def check(entry_point_group, fold, limit): # numpydoc ignore=PR01 for i, batch in enumerate(loader): if loader_limit == 0: break - logger.info( - f"{batch[1]['name'][0]}: " - f"{[s for s in batch[0][0].shape]}@{batch[0][0].dtype}", - ) + if isinstance(batch[0], dict): + logger.info( + f"{batch[1]['name'][0]}: " + f"{[s for s in batch[0]['image'][0].shape]}@{batch[0]['image'][0].dtype}", + ) + else: + logger.info( + f"{batch[1]['name'][0]}: " + f"{[s for s in batch[0][0].shape]}@{batch[0][0].dtype}", + ) loader_limit -= 1 except Exception: logger.exception(f"Unable to load batch {i} in dataset {k}")