Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.pytorch
Commits
b3319f64
Commit
b3319f64
authored
Jul 23, 2018
by
Guillaume HEUSCH
Browse files
[test] added unit tests for DCGAN
parent
2221e9af
Pipeline
#22433
passed with stage
in 18 minutes and 14 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
bob/learn/pytorch/test/test.py
View file @
b3319f64
...
...
@@ -32,6 +32,21 @@ def test_architectures():
output
,
emdedding
=
net
.
forward
(
t
)
assert
output
.
shape
==
torch
.
Size
([
1
,
20
])
assert
emdedding
.
shape
==
torch
.
Size
([
1
,
512
])
# DCGAN
d
=
numpy
.
random
.
rand
(
1
,
3
,
64
,
64
).
astype
(
"float32"
)
t
=
torch
.
from_numpy
(
d
)
from
..architectures
import
DCGAN_discriminator
discriminator
=
DCGAN_discriminator
(
1
)
output
=
discriminator
.
forward
(
t
)
assert
output
.
shape
==
torch
.
Size
([
1
])
g
=
numpy
.
random
.
rand
(
1
,
100
,
1
,
1
).
astype
(
"float32"
)
t
=
torch
.
from_numpy
(
g
)
from
..architectures
import
DCGAN_generator
generator
=
DCGAN_generator
(
1
)
output
=
generator
.
forward
(
t
)
assert
output
.
shape
==
torch
.
Size
([
1
,
3
,
64
,
64
])
def
test_transforms
():
...
...
@@ -85,7 +100,7 @@ class DummyDataSet(Dataset):
return
sample
def
test_trainer
():
def
test_
CNN
trainer
():
from
..architectures
import
CNN8
net
=
CNN8
(
20
)
...
...
@@ -100,3 +115,37 @@ def test_trainer():
assert
os
.
path
.
isfile
(
'model_1_0.pth'
)
os
.
remove
(
'model_1_0.pth'
)
class
DummyDataSetGAN
(
Dataset
):
def
__init__
(
self
):
pass
def
__len__
(
self
):
return
100
def
__getitem__
(
self
,
idx
):
data
=
numpy
.
random
.
rand
(
3
,
64
,
64
).
astype
(
"float32"
)
sample
=
{
'image'
:
torch
.
from_numpy
(
data
)}
return
sample
def
test_DCGANtrainer
():
from
..architectures
import
DCGAN_generator
from
..architectures
import
DCGAN_discriminator
g
=
DCGAN_generator
(
1
)
d
=
DCGAN_discriminator
(
1
)
dataloader
=
torch
.
utils
.
data
.
DataLoader
(
DummyDataSetGAN
(),
batch_size
=
32
,
shuffle
=
True
)
from
..trainers
import
DCGANTrainer
trainer
=
DCGANTrainer
(
g
,
d
,
batch_size
=
32
,
noise_dim
=
100
,
use_gpu
=
False
,
verbosity_level
=
2
)
trainer
.
train
(
dataloader
,
n_epochs
=
1
,
output_dir
=
'.'
)
import
os
assert
os
.
path
.
isfile
(
'fake_samples_epoch_000.png'
)
assert
os
.
path
.
isfile
(
'netD_epoch_0.pth'
)
assert
os
.
path
.
isfile
(
'netG_epoch_0.pth'
)
os
.
remove
(
'fake_samples_epoch_000.png'
)
os
.
remove
(
'netD_epoch_0.pth'
)
os
.
remove
(
'netG_epoch_0.pth'
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment