Skip to content
Snippets Groups Projects
Commit 1e2c50e0 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[lwnet] Add parameter for crop size

parent 129a17a1
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -23,4 +23,5 @@ model = LittleWNet( ...@@ -23,4 +23,5 @@ model = LittleWNet(
optimizer_type=Adam, optimizer_type=Adam,
optimizer_arguments=dict(lr=max_lr), optimizer_arguments=dict(lr=max_lr),
augmentation_transforms=[], augmentation_transforms=[],
crop_size=544,
) )
...@@ -260,6 +260,8 @@ class LittleWNet(pl.LightningModule): ...@@ -260,6 +260,8 @@ class LittleWNet(pl.LightningModule):
applied on the input **before** it is fed into the network. applied on the input **before** it is fed into the network.
num_classes num_classes
Number of outputs (classes) for this model. Number of outputs (classes) for this model.
crop_size
The size of the image after center cropping.
""" """
def __init__( def __init__(
...@@ -270,13 +272,14 @@ class LittleWNet(pl.LightningModule): ...@@ -270,13 +272,14 @@ class LittleWNet(pl.LightningModule):
optimizer_arguments: dict[str, typing.Any] = {}, optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [],
num_classes: int = 1, num_classes: int = 1,
crop_size: int = 544,
): ):
super().__init__() super().__init__()
self.name = "lwnet" self.name = "lwnet"
self.num_classes = num_classes self.num_classes = num_classes
self.model_transforms = [CenterCrop(size=(544, 544))] self.model_transforms = [CenterCrop(size=(crop_size, crop_size))]
self._train_loss = train_loss self._train_loss = train_loss
self._validation_loss = ( self._validation_loss = (
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment