Commit a9f7865a authored by André Anjos

[models.classify] Replace model_ft -> model

parent a504c37c
1 merge request: !64 Add object detection
@@ -6,11 +6,12 @@
import logging
import typing
import lightning.pytorch
import lightning.pytorch.callbacks
import torch.utils.data
from ...engine.device import DeviceManager
from ...models.classify.typing import Prediction, PredictionSplit
+ from ...utils.string import rewrap
logger = logging.getLogger(__name__)
@@ -63,6 +64,8 @@ class _JSONMetadataCollector(lightning.pytorch.callbacks.BasePredictionWriter):
dataloader_idx
Index of the dataloader overall.
"""
+ del trainer, pl_module, batch_indices, batch_idx, dataloader_idx
for k, sample_pred in enumerate(prediction):
sample_name: str = batch["name"][k]
target_shape = batch["target"][k].shape
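Aside (not part of the diff): the `del` line added above is a common idiom for Lightning's BasePredictionWriter, whose write_on_batch_end hook receives more arguments than this collector needs; deleting the unused names keeps linters quiet without changing behavior. A minimal sketch of such a writer, with hypothetical collection logic:

import lightning.pytorch.callbacks

class MetadataCollector(lightning.pytorch.callbacks.BasePredictionWriter):
    """Sketch: collects (name, prediction) pairs at the end of each batch."""

    def __init__(self):
        super().__init__(write_interval="batch")
        self.collected = []

    def write_on_batch_end(
        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
    ):
        # silence unused-argument warnings, as in the diff above
        del trainer, pl_module, batch_indices, batch_idx, dataloader_idx
        for k, sample_pred in enumerate(prediction):
            self.collected.append((batch["name"][k], sample_pred))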
@@ -100,7 +103,7 @@ def run(
model
Neural network model (e.g. pasa).
datamodule
- The lightning DataModule to use for training **and** validation.
+ The lightning DataModule to run predictions on.
device_manager
An internal device representation, to be used for training and
validation. This representation can be converted into a pytorch device
@@ -170,7 +173,9 @@ def run(
# if you get to this point, then the user is returning something that is
# not supported - complain!
raise TypeError(
f"Datamodule returned strangely typed prediction "
f"dataloaders: `{type(dataloaders)}` - Please write code "
f"to support this use-case.",
rewrap(
f"""Datamodule returned strangely typed prediction dataloaders:
`{type(dataloaders)}` - if this is not an error, write code to support this
use-case."""
)
)
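For reference (not in the diff): rewrap, imported above from ...utils.string, presumably collapses the triple-quoted string's hard line breaks and re-wraps the message; a minimal sketch of such a helper using the standard library (the project's actual implementation may differ):

import textwrap

def rewrap(text: str, width: int = 72) -> str:
    # collapse internal newlines and indentation, then re-wrap to `width` columns
    return textwrap.fill(" ".join(text.split()), width=width)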
@@ -310,7 +310,7 @@ def run(
target_layers = [model.fc14] # Last non-1x1 Conv2d layer
elif isinstance(model, Densenet):
target_layers = [
- model.model_ft.features.denseblock4.denselayer16.conv2, # type: ignore
+ model.model.features.denseblock4.denselayer16.conv2, # type: ignore
]
else:
raise TypeError(f"Model of type `{type(model)}` is not yet supported.")
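For context (not part of the diff): target_layers follows the pytorch-grad-cam convention, where saliency is computed with respect to the activations of the listed layers; a minimal usage sketch, assuming the pytorch_grad_cam package and an input batch `images`:

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# compute a class-0 saliency map over the chosen layers
cam = GradCAM(model=model, target_layers=target_layers)
saliency = cam(input_tensor=images, targets=[ClassifierOutputTarget(0)])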
@@ -193,7 +193,7 @@ def run(
target_layers = [model.fc14] # Last non-1x1 Conv2d layer
elif isinstance(model, Densenet):
target_layers = [
- model.model_ft.features.denseblock4.denselayer16.conv2, # type: ignore
+ model.model.features.denseblock4.denselayer16.conv2, # type: ignore
]
else:
raise TypeError(f"Model of type `{type(model)}` is not yet supported.")
@@ -94,13 +94,13 @@ class Alexnet(Model):
logger.info(f"Loading pretrained `{self.name}` model weights")
weights = models.AlexNet_Weights.DEFAULT
- self.model_ft = models.alexnet(weights=weights)
+ self.model = models.alexnet(weights=weights)
- self.model_ft.classifier[4] = torch.nn.Linear(
- in_features=self.model_ft.classifier[1].out_features, out_features=512
+ self.model.classifier[4] = torch.nn.Linear(
+ in_features=self.model.classifier[1].out_features, out_features=512
)
- self.model_ft.classifier[6] = torch.nn.Linear(
- in_features=self.model_ft.classifier[4].out_features,
+ self.model.classifier[6] = torch.nn.Linear(
+ in_features=self.model.classifier[4].out_features,
out_features=self.num_classes,
)
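For context (not part of the diff): in torchvision's AlexNet, `classifier` is a `Sequential` in which indices 1, 4, and 6 are the `Linear` layers (9216->4096, 4096->4096, and 4096->1000, respectively). The code above therefore shrinks the middle layer to 512 features and re-points the final layer at `num_classes`; a standalone sketch with an example class count:

import torch
from torchvision import models

m = models.alexnet(weights=None)
# replace the 4096->4096 hidden layer with 4096->512
m.classifier[4] = torch.nn.Linear(m.classifier[1].out_features, 512)
# replace the output layer: 512 -> 2 classes (example value)
m.classifier[6] = torch.nn.Linear(m.classifier[4].out_features, 2)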
@@ -111,14 +111,14 @@ class Alexnet(Model):
f"Resetting `{self.name}` output classifier layer weights due "
f"to a change in output size ({self.num_classes} -> {v})"
)
- self.model_ft.classifier[6] = torch.nn.Linear(
- in_features=self.model_ft.classifier[4].out_features, out_features=v
+ self.model.classifier[6] = torch.nn.Linear(
+ in_features=self.model.classifier[4].out_features, out_features=v
)
self._num_classes = v
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
# reset number of output classes if need be
- self.num_classes = checkpoint["state_dict"]["model_ft.classifier.bias"].shape[0]
+ self.num_classes = checkpoint["state_dict"]["model.classifier.bias"].shape[0]
# perform routine checkpoint loading
super().on_load_checkpoint(checkpoint)
@@ -136,4 +136,4 @@ class Alexnet(Model):
The prediction, as a tensor.
"""
x = self.normalizer(x)
- return self.model_ft(x)
+ return self.model(x)
@@ -96,11 +96,11 @@ class Densenet(Model):
logger.info(f"Loading pretrained `{self.name}` model weights")
weights = models.DenseNet121_Weights.DEFAULT
- self.model_ft = models.densenet121(weights=weights, drop_rate=self.dropout)
+ self.model = models.densenet121(weights=weights, drop_rate=self.dropout)
# output layer
- self.model_ft.classifier = torch.nn.Linear(
- self.model_ft.classifier.in_features, self.num_classes
+ self.model.classifier = torch.nn.Linear(
+ self.model.classifier.in_features, self.num_classes
)
@Model.num_classes.setter # type: ignore[attr-defined]
@@ -110,18 +110,26 @@ class Densenet(Model):
f"Resetting `{self.name}` output classifier layer weights due "
f"to a change in output size ({self.num_classes} -> {v})"
)
- self.model_ft.classifier = torch.nn.Linear(
- self.model_ft.classifier.in_features, v
+ self.model.classifier = torch.nn.Linear(
+ self.model.classifier.in_features, v
)
self._num_classes = v
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
+ # support previous version of densenet (model_ft -> model)
+ if any([k.startswith("model_ft") for k in checkpoint["state_dict"].keys()]):
+ # convert all "model_ft" entries to "model"
+ checkpoint["state_dict"] = {
+ k.replace("model_ft", "model"): v
+ for k, v in checkpoint["state_dict"].items()
+ }
# reset number of output classes if need be
- self.num_classes = checkpoint["state_dict"]["model_ft.classifier.bias"].shape[0]
+ self.num_classes = checkpoint["state_dict"]["model.classifier.bias"].shape[0]
# perform routine checkpoint loading
super().on_load_checkpoint(checkpoint)
def forward(self, x):
x = self.normalizer(x)
- return self.model_ft(x)
+ return self.model(x)
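The compatibility shim added to on_load_checkpoint above means checkpoints saved before this rename keep loading; a minimal illustration of the key migration, with a hypothetical two-class state dict:

import torch

# hypothetical checkpoint saved before the rename
checkpoint = {"state_dict": {"model_ft.classifier.bias": torch.zeros(2)}}
if any(k.startswith("model_ft") for k in checkpoint["state_dict"]):
    checkpoint["state_dict"] = {
        k.replace("model_ft", "model"): v
        for k, v in checkpoint["state_dict"].items()
    }
assert "model.classifier.bias" in checkpoint["state_dict"]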