Skip to content
Snippets Groups Projects
Commit 684bd8ce authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[tbx11k] Add parameter to avoid returning bounding boxes

parent 0d2f0163
No related branches found
No related tags found
1 merge request!55Fix montgomery-shenzhen-indian-tbx11k datamodule
Pipeline #89098 failed
......@@ -40,7 +40,7 @@ class DataModule(ConcatDataModule):
indian_loader = IndianLoader(INDIAN_KEY_DATADIR)
indian_package = IndianDataModule.__module__.rsplit(".", 1)[0]
indian_split = make_split(indian_package, split_filename)
tbx11k_loader = TBX11kLoader()
tbx11k_loader = TBX11kLoader(ignore_bboxes=True)
tbx11k_package = TBX11kLoader.__module__.rsplit(".", 1)[0]
tbx11k_split = make_split(tbx11k_package, tbx11k_split_filename)
......
......@@ -168,19 +168,26 @@ finding locations, as described above.
class RawDataLoader(_BaseRawDataLoader):
"""A specialized raw-data-loader for the TBX11k dataset."""
"""A specialized raw-data-loader for the TBX11k dataset.
Parameters
----------
ignore_bboxes
If True, sample() does not return bounding boxes.
"""
datadir: pathlib.Path
"""This variable contains the base directory where the database raw data is
stored."""
def __init__(self):
def __init__(self, ignore_bboxes: bool = False):
self.datadir = pathlib.Path(
load_rc().get(
CONFIGURATION_KEY_DATADIR,
os.path.realpath(os.curdir),
),
)
self.ignore_bboxes = ignore_bboxes
def sample(self, sample: DatabaseSample) -> Sample:
"""Load a single image sample from the disk.
......@@ -206,6 +213,12 @@ class RawDataLoader(_BaseRawDataLoader):
# to_pil_image(tensor).show()
# __import__("pdb").set_trace()
if self.ignore_bboxes:
return tensor, dict(
label=sample[1],
name=sample[0],
)
return tensor, dict(
label=sample[1],
name=sample[0],
......@@ -356,13 +369,15 @@ class DataModule(CachingDataModule):
----------
split_filename
Name of the .json file containing the split to load.
ignore_bboxes
If True, sample() does not return bounding boxes.
"""
def __init__(self, split_filename: str):
def __init__(self, split_filename: str, ignore_bboxes: bool = False):
assert __package__ is not None
super().__init__(
database_split=make_split(__package__, split_filename),
raw_data_loader=RawDataLoader(),
raw_data_loader=RawDataLoader(ignore_bboxes=ignore_bboxes),
database_name=__package__.rsplit(".", 1)[1],
split_name=pathlib.Path(split_filename).stem,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment