Skip to content
Snippets Groups Projects

WIP: Generic trainer

Closed Anjith GEORGE requested to merge generic_trainer into master
3 files
+ 22
21
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -137,33 +137,24 @@ optimizer = optim.Adam(filter(lambda p: p.requires_grad, network.parameters()),l
def compute_loss(network,img, labels, device):
img = img.to(device)
"""
Compute the losses, given the network, data and labels and
device in which the computation will be performed.
"""
if not isinstance(labels,list):
imagesv = Variable(img['image'].to(device))
labels=[labels]
labelsv_pixel = Variable(labels['pixel_mask'].to(device))
labelsv=[]
for label in labels:
labelsv.append(label.to(device))
labels=labelsv.copy()
labelsv_binary = Variable(labels['binary_target'].to(device))
imagesv = Variable(img)
labelsv=[]
for label in labels:
labelsvt = Variable(label) # needs a list here
labelsv.append(labelsvt)
out= network(imagesv)
out = network(imagesv)
beta=0.5
loss_pixel = criterion_pixel(out[0].squeeze(1),labelsv[0].float())
loss_pixel = criterion_pixel(out[0].squeeze(1),labelsv_pixel.float())
loss_bce = criterion_bce(out[1],labelsv[1].unsqueeze(1).float())
loss_bce = criterion_bce(out[1],labelsv_binary.unsqueeze(1).float())
loss = beta*loss_bce + (1.0-beta)*loss_pixel
Loading