diff --git a/src/mednet/libs/segmentation/config/models/lwnet.py b/src/mednet/libs/segmentation/config/models/lwnet.py index 9cb08ea3b463ed79c9b2c8c4dbc670eae7b9fa64..09af9f842c731bcaa23e284eb292d90134664883 100644 --- a/src/mednet/libs/segmentation/config/models/lwnet.py +++ b/src/mednet/libs/segmentation/config/models/lwnet.py @@ -23,4 +23,5 @@ model = LittleWNet( optimizer_type=Adam, optimizer_arguments=dict(lr=max_lr), augmentation_transforms=[], + crop_size=544, ) diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py index 6cad67c925c5e9ff0a38527ba328b36242b62ab3..d8e5b2cbd23119eb3d02cd5d2ac9e9342ac3ee88 100644 --- a/src/mednet/libs/segmentation/models/lwnet.py +++ b/src/mednet/libs/segmentation/models/lwnet.py @@ -260,6 +260,8 @@ class LittleWNet(pl.LightningModule): applied on the input **before** it is fed into the network. num_classes Number of outputs (classes) for this model. + crop_size + The size of the image after center cropping. """ def __init__( @@ -270,13 +272,14 @@ class LittleWNet(pl.LightningModule): optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], num_classes: int = 1, + crop_size: int = 544, ): super().__init__() self.name = "lwnet" self.num_classes = num_classes - self.model_transforms = [CenterCrop(size=(544, 544))] + self.model_transforms = [CenterCrop(size=(crop_size, crop_size))] self._train_loss = train_loss self._validation_loss = (