diff --git a/src/mednet/data/segment/shenzhen.py b/src/mednet/data/segment/shenzhen.py
index 02570d7bd9a6b1af761a5862dc49d7dc31475803..96f073a284e4e58bb701a4f32573b0de6e873bbf 100644
--- a/src/mednet/data/segment/shenzhen.py
+++ b/src/mednet/data/segment/shenzhen.py
@@ -67,7 +67,6 @@ import torch
 from torchvision import tv_tensors
 from torchvision.transforms.v2.functional import to_dtype, to_image
 
-from ...models.transforms import crop_image_to_mask
 from ...utils.rc import load_rc
 from ..datamodule import CachingDataModule
 from ..split import JSONDatabaseSplit
@@ -116,8 +115,8 @@ class RawDataLoader(BaseDataLoader):
         # use image as a base since target() can be overriden by child class
         mask = torch.ones((1, image.shape[-2], image.shape[-1]), dtype=torch.float32)
 
-        image = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(image)
+        target = tv_tensors.Mask(target)
         mask = tv_tensors.Mask(mask)
 
         return dict(image=image, target=target, mask=mask, name=sample[0])