Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Max Ehrlich
cdcnn
Commits
d4208bd3
Verified
Commit
d4208bd3
authored
Dec 13, 2018
by
Max Ehrlich
Browse files
Add relu training experiment
parent
5b29b727
Changes
1
Hide whitespace changes
Inline
Side-by-side
experiments/relu_training.py
0 → 100644
View file @
d4208bd3
import
models
import
torch
import
torch.optim
as
optim
import
argparse
import
data
device
=
torch
.
device
(
'cuda'
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--models'
,
type
=
int
,
help
=
'Number of models to use'
)
parser
.
add_argument
(
'--epochs'
,
type
=
int
,
help
=
'Number of epochs to train for'
)
parser
.
add_argument
(
'--dataset'
,
choices
=
data
.
spatial_dataset_map
.
keys
(),
help
=
'Dataset to use'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
help
=
'Batch size'
)
parser
.
add_argument
(
'--data'
,
help
=
'Root folder for the dataset'
)
args
=
parser
.
parse_args
()
spatial_accuracies
=
0
asm_accuracies
=
torch
.
zeros
(
15
)
apx_accuracies
=
torch
.
zeros
(
15
)
for
m
in
range
(
args
.
models
):
print
(
'Train spatial model {}/{}'
.
format
(
m
+
1
,
args
.
models
))
spatial_dataset
=
data
.
spatial_dataset_map
[
args
.
dataset
](
args
.
batch_size
,
args
.
data
)
dataset_info
=
data
.
dataset_info
[
args
.
dataset
]
spatial_model
=
models
.
SpatialResNet
(
dataset_info
[
'channels'
],
dataset_info
[
'classes'
]).
to
(
device
)
optimizer
=
optim
.
Adam
(
spatial_model
.
parameters
())
for
e
in
range
(
args
.
epochs
):
models
.
train
(
spatial_model
,
device
,
spatial_dataset
[
0
],
optimizer
,
e
)
acc
=
models
.
test
(
spatial_model
,
device
,
spatial_dataset
[
1
])
spatial_accuracies
+=
acc
jpeg_dataset
=
data
.
jpeg_dataset_map
[
args
.
dataset
](
args
.
batch_size
,
args
.
data
)
for
f
in
range
(
15
):
print
(
'Train ASM JPEG with {} spatial frequencies'
.
format
(
f
))
jpeg_model
=
models
.
JpegResNet
(
models
.
SpatialResNet
(
dataset_info
[
'channels'
],
dataset_info
[
'classes'
]),
n_freqs
=
f
).
to
(
device
)
optimizer
=
optim
.
Adam
(
jpeg_model
.
parameters
())
for
e
in
range
(
args
.
epochs
):
models
.
train
(
jpeg_model
,
device
,
jpeg_dataset
[
0
],
optimizer
,
e
)
acc
=
models
.
test
(
jpeg_model
,
device
,
jpeg_dataset
[
1
])
asm_accuracies
[
f
]
+=
acc
print
(
'Train APX JPEG with {} spatial frequencies'
.
format
(
f
))
jpeg_model
=
models
.
JpegResNetApx
(
models
.
SpatialResNet
(
dataset_info
[
'channels'
],
dataset_info
[
'classes'
]),
n_freqs
=
f
).
to
(
device
)
optimizer
=
optim
.
Adam
(
jpeg_model
.
parameters
())
for
e
in
range
(
args
.
epochs
):
models
.
train
(
jpeg_model
,
device
,
jpeg_dataset
[
0
],
optimizer
,
e
)
acc
=
models
.
test
(
jpeg_model
,
device
,
jpeg_dataset
[
1
])
apx_accuracies
[
f
]
+=
acc
spatial_accuracies
/=
args
.
models
asm_accuracies
/=
args
.
models
apx_accuracies
/=
args
.
models
with
open
(
'{}_relu_training.csv'
.
format
(
args
.
dataset
),
'w'
)
as
f
:
f
.
write
(
'Spatial, ASM, APX
\n
'
)
for
i
in
range
(
15
):
f
.
write
(
'{}, {}, {}
\n
'
.
format
(
spatial_accuracies
,
asm_accuracies
[
i
],
apx_accuracies
[
i
]))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment