logger.info("The number of training samples: {}".format(dataset.__len__()))
ifcross_validate:# if cross validation is enabled:
logger.info("The number of cross-validation samples: {}".format(dataset_dev.__len__()))
dataloader=DataLoader(dataset,
batch_size=config_module.BATCH_SIZE,
shuffle=True)
ifcross_validate:# if cross validation is enabled:
dataloader_dev=DataLoader(dataset_dev,
batch_size=config_module.BATCH_SIZE,
shuffle=False)# shuffling is not needed in cross-validation
UNUSED=dataset.__getitem__(0)# call a dataset __getitem__ once, to **possibly** compute normalization parameters, after that num_workers can be set for dataloader
if"NUM_WORKERS"indir(config_module)anddataloader.num_workers==0:# set the number of workers for the DataLoader
dataloader.num_workers=config_module.NUM_WORKERS
ifverbosity>0:
logger.info("The number of workers for the DataLoader is: {}".format(dataloader.num_workers))