medai / software / deepdraw · Commit 7eb7def3
Authored 4 years ago by André Anjos

[engine.ssltrainer] Re-sync with engine.trainer

Parent: 5899f1b5
Merge request: !12 (Streamlining)

Showing 1 changed file: bob/ip/binseg/engine/ssltrainer.py (99 additions, 106 deletions)
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-

 import os
+import csv
 import time
 import datetime
 import torch
@@ -21,7 +22,7 @@ def sharpen(x, T):
     return temp / temp.sum(dim=1, keepdim=True)


-def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
+def mix_up(alpha, input, target, unlabelled_input, unlabled_target):
     """
     Applies mix up as described in [MIXMATCH_19].

     Parameters
@@ -32,7 +33,7 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     target : :py:class:`torch.Tensor`

-    unlabeled_input : :py:class:`torch.Tensor`
+    unlabelled_input : :py:class:`torch.Tensor`

     unlabled_target : :py:class:`torch.Tensor`
@@ -48,17 +49,17 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     l = np.random.beta(alpha, alpha)  # Eq (8)
     l = max(l, 1 - l)  # Eq (9)

     # Shuffle and concat. Alg. 1 Line: 12
-    w_inputs = torch.cat([input, unlabeled_input], 0)
+    w_inputs = torch.cat([input, unlabelled_input], 0)
     w_targets = torch.cat([target, unlabled_target], 0)
     idx = torch.randperm(w_inputs.size(0))  # get random index

-    # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
+    # Apply MixUp to labelled data and entries from W. Alg. 1 Line: 13
     input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input):]]
     target_mixedup = l * target + (1 - l) * w_targets[idx[len(target):]]

-    # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
-    unlabeled_input_mixedup = (
-        l * unlabeled_input + (1 - l) * w_inputs[idx[:len(unlabeled_input)]]
+    # Apply MixUp to unlabelled data and entries from W. Alg. 1 Line: 14
+    unlabelled_input_mixedup = (
+        l * unlabelled_input + (1 - l) * w_inputs[idx[:len(unlabelled_input)]]
     )
     unlabled_target_mixedup = (
         l * unlabled_target + (1 - l) * w_targets[idx[:len(unlabled_target)]]
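The hunk above keeps the MixMatch mixing-coefficient logic intact: `l` is drawn from a symmetric Beta(alpha, alpha) distribution (Eq. 8) and clamped with `max(l, 1 - l)` (Eq. 9), so every mixed sample stays dominated by its first argument. A minimal, self-contained sketch of that coefficient logic follows; the tensor shapes and the alpha value are illustrative stand-ins, not the module's API:

```python
import numpy as np
import torch

alpha = 0.75  # the value used in the commented-out mix_up call further down

# Eq (8): draw the mixing coefficient from a symmetric Beta distribution
l = np.random.beta(alpha, alpha)
# Eq (9): clamp so l >= 0.5 and each mixed sample stays closest to its own batch
l = max(l, 1 - l)

x1 = torch.rand(4, 1, 8, 8)  # stand-in for a labelled batch
x2 = torch.rand(4, 1, 8, 8)  # stand-in for shuffled entries of W
mixed = l * x1 + (1 - l) * x2  # convex combination, as in the hunk above
print(l, mixed.shape)
```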
@@ -66,7 +67,7 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     return (
         input_mixedup,
         target_mixedup,
-        unlabeled_input_mixedup,
+        unlabelled_input_mixedup,
         unlabled_target_mixedup,
     )
@@ -122,14 +123,14 @@ def linear_rampup(current, rampup_length=16):
         return float(current)


-def guess_labels(unlabeled_images, model):
+def guess_labels(unlabelled_images, model):
     """
     Calculate the average predictions by 2 augmentations: horizontal and vertical flips

     Parameters
     ----------

-    unlabeled_images : :py:class:`torch.Tensor`
+    unlabelled_images : :py:class:`torch.Tensor`
         ``[n,c,h,w]``

     target : :py:class:`torch.Tensor`
@@ -142,12 +143,12 @@ def guess_labels(unlabeled_images, model):
     """
     with torch.no_grad():
-        guess1 = torch.sigmoid(model(unlabeled_images)).unsqueeze(0)
+        guess1 = torch.sigmoid(model(unlabelled_images)).unsqueeze(0)
         # Horizontal flip and unsqueeze to work with batches (increase flip dimension by 1)
-        hflip = torch.sigmoid(model(unlabeled_images.flip(2))).unsqueeze(0)
+        hflip = torch.sigmoid(model(unlabelled_images.flip(2))).unsqueeze(0)
         guess2 = hflip.flip(3)
         # Vertical flip and unsqueeze to work with batches (increase flip dimension by 1)
-        vflip = torch.sigmoid(model(unlabeled_images.flip(3))).unsqueeze(0)
+        vflip = torch.sigmoid(model(unlabelled_images.flip(3))).unsqueeze(0)
         guess3 = vflip.flip(4)
         # Concat
         concat = torch.cat([guess1, guess2, guess3], 0)
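The renames above leave the test-time-augmentation logic unchanged: predict on the raw batch plus horizontally and vertically flipped copies, flip the predictions back, and average (per the docstring). A hedged sketch of that pattern with a stand-in network; the Conv2d model and shapes are placeholders, not the project's code:

```python
import torch

model = torch.nn.Conv2d(1, 1, 3, padding=1)  # stand-in network, not the project's model
images = torch.rand(2, 1, 16, 16)  # [n, c, h, w]

with torch.no_grad():
    guess1 = torch.sigmoid(model(images)).unsqueeze(0)
    # flip(2) mirrors the height axis; after unsqueeze(0) it sits at dim 3
    guess2 = torch.sigmoid(model(images.flip(2))).unsqueeze(0).flip(3)
    # flip(3) mirrors the width axis; after unsqueeze(0) it sits at dim 4
    guess3 = torch.sigmoid(model(images.flip(3))).unsqueeze(0).flip(4)
    # stack the three guesses and average them into one prediction per image
    average = torch.cat([guess1, guess2, guess3], 0).mean(0)

print(average.shape)  # torch.Size([2, 1, 16, 16])
```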
@@ -169,13 +170,13 @@ def do_ssltrain(
     rampup_length,
 ):
     """
-    Train model and save to disk.
+    Trains model using semi-supervised learning and saves it to disk.

     Parameters
     ----------

     model : :py:class:`torch.nn.Module`
-        Network (e.g. DRIU, HED, UNet)
+        Network (e.g. driu, hed, unet)

     data_loader : :py:class:`torch.utils.data.DataLoader`
@@ -191,13 +192,14 @@ def do_ssltrain(
     checkpointer

     checkpoint_period : int
-        save a checkpoint every n epochs
+        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
+        not save intermediary checkpoints

     device : str
-        device to use ``'cpu'`` or ``'cuda'``
+        device to use ``'cpu'`` or ``cuda:0``

     arguments : dict
-        start end end epochs
+        start and end epochs

     output_folder : str
         output path
@@ -206,15 +208,35 @@ def do_ssltrain(
         rampup epochs
     """
     logger.info("Start SSL training")
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]

-    # Logg to file
-    with open(
-        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)),
-        "a+",
-        1,
-    ) as outfile:
+    # Log to file
+    logfile_name = os.path.join(output_folder, "trainlog.csv")
+
+    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
+        logger.info(f"Truncating {logfile_name} - training is restarting...")
+        os.unlink(logfile_name)
+
+    logfile_fields = (
+        "epoch",
+        "total-time",
+        "eta",
+        "average-loss",
+        "median-loss",
+        "median-labelled-loss",
+        "median-unlabelled-loss",
+        "learning-rate",
+        "gpu-memory-megabytes",
+    )
+
+    with open(logfile_name, "a+", newline="") as logfile:
+        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
+
+        if arguments["epoch"] == 0:
+            logwriter.writeheader()

         for state in optimizer.state.values():
             for k, v in state.items():
                 if isinstance(v, torch.Tensor):
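This hunk replaces a hand-formatted, line-buffered text log with `csv.DictWriter`. A small sketch of the same pattern, assuming the restart semantics shown above (delete a stale log on a fresh run, append and skip the header otherwise); the path and abridged field list are illustrative:

```python
import csv
import os

logfile_name = "trainlog.csv"  # illustrative path, not the real output folder
logfile_fields = ("epoch", "average-loss")  # abridged field list
start_epoch = 0  # would come from arguments["epoch"]

# a fresh run (epoch 0) discards any stale log left by a previous experiment
if start_epoch == 0 and os.path.exists(logfile_name):
    os.unlink(logfile_name)

with open(logfile_name, "a+", newline="") as logfile:
    logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
    if start_epoch == 0:
        logwriter.writeheader()  # header written once, at the top of the file
    logwriter.writerow({"epoch": 1, "average-loss": 0.123456})
```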
@@ -223,125 +245,96 @@ def do_ssltrain(
         model.train().to(device)
         # Total training timer
         start_training_time = time.time()

         for epoch in range(start_epoch, max_epoch):
             scheduler.step()
             losses = SmoothedValue(len(data_loader))
-            labeled_loss = SmoothedValue(len(data_loader))
-            unlabeled_loss = SmoothedValue(len(data_loader))
+            labelled_loss = SmoothedValue(len(data_loader))
+            unlabelled_loss = SmoothedValue(len(data_loader))
             epoch = epoch + 1
             arguments["epoch"] = epoch

             # Epoch time
             start_epoch_time = time.time()

-            for samples in tqdm(data_loader):
-                # labeled
+            for samples in tqdm(
+                data_loader, desc="batches", leave=False, disable=None,
+            ):
+                # data forwarding on the existing network
+                # labelled
                 images = samples[1].to(device)
                 ground_truths = samples[2].to(device)
-                unlabeled_images = samples[4].to(device)
-                # labeled outputs
+                unlabelled_images = samples[4].to(device)
+                # labelled outputs
                 outputs = model(images)
-                unlabeled_outputs = model(unlabeled_images)
-                # guessed unlabeled outputs
-                unlabeled_ground_truths = guess_labels(unlabeled_images, model)
-                # unlabeled_ground_truths = sharpen(unlabeled_ground_truths,0.5)
-                # images, ground_truths, unlabeled_images, unlabeled_ground_truths = mix_up(0.75, images, ground_truths, unlabeled_images, unlabeled_ground_truths)
+                unlabelled_outputs = model(unlabelled_images)
+                # guessed unlabelled outputs
+                unlabelled_ground_truths = guess_labels(unlabelled_images, model)
+                # unlabelled_ground_truths = sharpen(unlabelled_ground_truths,0.5)
+                # images, ground_truths, unlabelled_images, unlabelled_ground_truths = mix_up(0.75, images, ground_truths, unlabelled_images, unlabelled_ground_truths)

+                # loss evaluation and learning (backward step)
                 ramp_up_factor = square_rampup(epoch, rampup_length=rampup_length)

                 loss, ll, ul = criterion(
                     outputs,
                     ground_truths,
-                    unlabeled_outputs,
-                    unlabeled_ground_truths,
+                    unlabelled_outputs,
+                    unlabelled_ground_truths,
                     ramp_up_factor,
                 )
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()

                 losses.update(loss)
-                labeled_loss.update(ll)
-                unlabeled_loss.update(ul)
-                logger.debug("batch loss: {}".format(loss.item()))
+                labelled_loss.update(ll)
+                unlabelled_loss.update(ul)
+                logger.debug(f"batch loss: {loss.item()}")

-            if epoch % checkpoint_period == 0:
-                checkpointer.save("model_{:03d}".format(epoch), **arguments)
+            if checkpoint_period and (epoch % checkpoint_period == 0):
+                checkpointer.save(f"model_{epoch:03d}", **arguments)

-            if epoch == max_epoch:
+            if epoch >= max_epoch:
                 checkpointer.save("model_final", **arguments)

+            # computes ETA (estimated time-of-arrival; end of training) taking
+            # into consideration previous epoch performance
             epoch_time = time.time() - start_epoch_time
             eta_seconds = epoch_time * (max_epoch - epoch)
-            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+            current_time = time.time() - start_training_time

-            outfile.write(
-                (
-                    "{epoch},"
-                    "{avg_loss:.6f},"
-                    "{median_loss:.6f},"
-                    "{median_labeled_loss},"
-                    "{median_unlabeled_loss},"
-                    "{lr:.6f},"
-                    "{memory:.0f}"
-                    "\n"
-                ).format(
-                    eta=eta_string,
-                    epoch=epoch,
-                    avg_loss=losses.avg,
-                    median_loss=losses.median,
-                    median_labeled_loss=labeled_loss.median,
-                    median_unlabeled_loss=unlabeled_loss.median,
-                    lr=optimizer.param_groups[0]["lr"],
-                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
-                    if torch.cuda.is_available()
-                    else 0.0,
-                )
-            )
-            logger.info(
-                (
-                    "eta: {eta}, "
-                    "epoch: {epoch}, "
-                    "avg. loss: {avg_loss:.6f}, "
-                    "median loss: {median_loss:.6f}, "
-                    "labeled loss: {median_labeled_loss}, "
-                    "unlabeled loss: {median_unlabeled_loss}, "
-                    "lr: {lr:.6f}, "
-                    "max mem: {memory:.0f}"
-                ).format(
-                    eta=eta_string,
-                    epoch=epoch,
-                    avg_loss=losses.avg,
-                    median_loss=losses.median,
-                    median_labeled_loss=labeled_loss.median,
-                    median_unlabeled_loss=unlabeled_loss.median,
-                    lr=optimizer.param_groups[0]["lr"],
-                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
-                    if torch.cuda.is_available()
-                    else 0.0,
-                )
-            )
+            logdata = (
+                ("epoch", f"{epoch}"),
+                (
+                    "total-time",
+                    f"{datetime.timedelta(seconds=int(current_time))}",
+                ),
+                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
+                ("average-loss", f"{losses.avg:.6f}"),
+                ("median-loss", f"{losses.median:.6f}"),
+                ("median-labelled-loss", f"{labelled_loss.median:.6f}"),
+                ("median-unlabelled-loss", f"{unlabelled_loss.median:.6f}"),
+                ("learning-rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
+                (
+                    "gpu-memory-megabytes",
+                    f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}"
+                    if torch.cuda.is_available()
+                    else "0.0",
+                ),
+            )
+            logwriter.writerow(dict(k for k in logdata))
+            logger.info("|".join([f"{k}: {v}" for (k, v) in logdata]))

+        logger.info("End of training")
         total_training_time = time.time() - start_training_time
-        total_time_str = str(datetime.timedelta(seconds=total_training_time))
         logger.info(
-            "Total training time: {} ({:.4f} s / epoch)".format(
-                total_time_str, total_training_time / (max_epoch)
-            )
+            f"Total training time: {datetime.timedelta(seconds=total_training_time)} "
+            f"({(total_training_time/max_epoch):.4f} s in average per epoch)"
         )

-    log_plot_file = os.path.join(output_folder, "{}_trainlog.pdf".format(model.name))
-    logdf = pd.read_csv(
-        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)),
-        header=None,
-        names=[
-            "avg. loss",
-            "median loss",
-            "labeled loss",
-            "unlabeled loss",
-            "lr",
-            "max memory",
-        ],
-    )
-    fig = loss_curve(logdf, output_folder)
-    logger.info("saving {}".format(log_plot_file))
-    fig.savefig(log_plot_file)
+    # plots a version of the CSV trainlog into a PDF
+    logdf = pd.read_csv(logfile_name, header=0, names=logfile_fields)
+    fig = loss_curve(logdf, title="Loss Evolution")
+    figurefile_name = os.path.join(output_folder, "trainlog.pdf")
+    logger.info(f"Saving {figurefile_name}")
+    fig.savefig(figurefile_name)
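The loop above scales the unlabelled loss term by `square_rampup(epoch, rampup_length=rampup_length)`, whose definition is outside this diff. A plausible, purely illustrative sketch of such a factor (quadratic growth to 1.0 over `rampup_length` epochs, clamped thereafter) could look like:

```python
import numpy as np

def square_rampup(current, rampup_length=16):
    """Hypothetical quadratic ramp-up: grows from 0 to 1 over rampup_length epochs.

    Illustrates the role the factor plays in the loop above; the real
    implementation lives elsewhere in bob/ip/binseg and may differ.
    """
    if rampup_length == 0:
        return 1.0
    return float(np.clip(current / rampup_length, 0.0, 1.0) ** 2)

# early epochs rely mostly on the supervised loss; the unlabelled term phases in
print([square_rampup(e) for e in (0, 4, 8, 16, 32)])  # [0.0, 0.0625, 0.25, 1.0, 1.0]
```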