diff --git a/src/mednet/data/segment/shenzhen.py b/src/mednet/data/segment/shenzhen.py index 02570d7bd9a6b1af761a5862dc49d7dc31475803..96f073a284e4e58bb701a4f32573b0de6e873bbf 100644 --- a/src/mednet/data/segment/shenzhen.py +++ b/src/mednet/data/segment/shenzhen.py @@ -67,7 +67,6 @@ import torch from torchvision import tv_tensors from torchvision.transforms.v2.functional import to_dtype, to_image -from ...models.transforms import crop_image_to_mask from ...utils.rc import load_rc from ..datamodule import CachingDataModule from ..split import JSONDatabaseSplit @@ -116,8 +115,8 @@ class RawDataLoader(BaseDataLoader): # use image as a base since target() can be overriden by child class mask = torch.ones((1, image.shape[-2], image.shape[-1]), dtype=torch.float32) - image = tv_tensors.Image(crop_image_to_mask(image, mask)) - target = tv_tensors.Mask(crop_image_to_mask(target, mask)) + image = tv_tensors.Image(image) + target = tv_tensors.Mask(target) mask = tv_tensors.Mask(mask) return dict(image=image, target=target, mask=mask, name=sample[0])