diff --git a/src/ptbench/configs/models/alexnet.py b/src/ptbench/configs/models/alexnet.py index 2361b886d500fee740e456f2505a36da4fdaf4e3..815226b517142438d77db25e69f6f4e173cee39c 100644 --- a/src/ptbench/configs/models/alexnet.py +++ b/src/ptbench/configs/models/alexnet.py @@ -4,32 +4,17 @@ """AlexNet.""" -from torch import empty from torch.nn import BCEWithLogitsLoss from torch.optim import SGD -from ...models.alexnet import Alexnet - -# optimizer -optimizer = SGD -optimizer_configs = {"lr": 0.01, "momentum": 0.1} - -# criterion -criterion = BCEWithLogitsLoss(pos_weight=empty(1)) -criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) - from ...data.transforms import ElasticDeformation +from ...models.alexnet import Alexnet -augmentation_transforms = [ - ElasticDeformation(p=0.8), -] - -# model model = Alexnet( - criterion, - criterion_valid, - optimizer, - optimizer_configs, + train_loss=BCEWithLogitsLoss(), + validation_loss=BCEWithLogitsLoss(), + optimizer_type=SGD, + optimizer_arguments=dict(lr=0.01, momentum=0.1), + augmentation_transforms=[ElasticDeformation(p=0.8)], pretrained=False, - augmentation_transforms=augmentation_transforms, ) diff --git a/src/ptbench/configs/models/alexnet_pretrained.py b/src/ptbench/configs/models/alexnet_pretrained.py index 0dc7e5d67d007cf5e7e358e7fa75243a47047c4b..f968df50cda171cc94991febc511168d111517c9 100644 --- a/src/ptbench/configs/models/alexnet_pretrained.py +++ b/src/ptbench/configs/models/alexnet_pretrained.py @@ -4,32 +4,17 @@ """AlexNet.""" -from torch import empty from torch.nn import BCEWithLogitsLoss from torch.optim import SGD -from ...models.alexnet import Alexnet - -# optimizer -optimizer = SGD -optimizer_configs = {"lr": 0.01, "momentum": 0.1} - -# criterion -criterion = BCEWithLogitsLoss(pos_weight=empty(1)) -criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) - from ...data.transforms import ElasticDeformation +from ...models.alexnet import Alexnet -augmentation_transforms = [ - ElasticDeformation(p=0.8), -] - -# model model = Alexnet( - criterion, - criterion_valid, - optimizer, - optimizer_configs, + train_loss=BCEWithLogitsLoss(), + validation_loss=BCEWithLogitsLoss(), + optimizer_type=SGD, + optimizer_arguments=dict(lr=0.01, momentum=0.1), + augmentation_transforms=[ElasticDeformation(p=0.8)], pretrained=True, - augmentation_transforms=augmentation_transforms, ) diff --git a/src/ptbench/configs/models/densenet.py b/src/ptbench/configs/models/densenet.py index 5d612b2a18146ba306b419a808936e9c3c7042f7..79f8f7dabc58746c1029bbc9760f10137801c202 100644 --- a/src/ptbench/configs/models/densenet.py +++ b/src/ptbench/configs/models/densenet.py @@ -4,32 +4,17 @@ """DenseNet.""" -from torch import empty from torch.nn import BCEWithLogitsLoss from torch.optim import Adam -from ...models.densenet import Densenet - -# optimizer -optimizer = Adam -optimizer_configs = {"lr": 0.0001} - -# criterion -criterion = BCEWithLogitsLoss(pos_weight=empty(1)) -criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) - from ...data.transforms import ElasticDeformation +from ...models.densenet import Densenet -augmentation_transforms = [ - ElasticDeformation(p=0.8), -] - -# model model = Densenet( - criterion, - criterion_valid, - optimizer, - optimizer_configs, + train_loss=BCEWithLogitsLoss(), + validation_loss=BCEWithLogitsLoss(), + optimizer_type=Adam, + optimizer_arguments=dict(lr=0.0001), + augmentation_transforms=[ElasticDeformation(p=0.8)], pretrained=False, - augmentation_transforms=augmentation_transforms, ) diff --git a/src/ptbench/configs/models/densenet_pretrained.py b/src/ptbench/configs/models/densenet_pretrained.py index f8908fdb1e87a62df41dca0ecb75ff1fc79b1012..4bc4616c6de0a19134646a4ad1449c2920be9e50 100644 --- a/src/ptbench/configs/models/densenet_pretrained.py +++ b/src/ptbench/configs/models/densenet_pretrained.py @@ -4,32 +4,17 @@ """DenseNet.""" -from torch import empty from torch.nn import BCEWithLogitsLoss from torch.optim import Adam -from ...models.densenet import Densenet - -# optimizer -optimizer = Adam -optimizer_configs = {"lr": 0.0001} - -# criterion -criterion = BCEWithLogitsLoss(pos_weight=empty(1)) -criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) - from ...data.transforms import ElasticDeformation +from ...models.densenet import Densenet -augmentation_transforms = [ - ElasticDeformation(p=0.8), -] - -# model model = Densenet( - criterion, - criterion_valid, - optimizer, - optimizer_configs, + train_loss=BCEWithLogitsLoss(), + validation_loss=BCEWithLogitsLoss(), + optimizer_type=Adam, + optimizer_arguments=dict(lr=0.0001), + augmentation_transforms=[ElasticDeformation(p=0.8)], pretrained=True, - augmentation_transforms=augmentation_transforms, )