Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Max Ehrlich
cdcnn
Commits
b19163d7
Verified
Commit
b19163d7
authored
Dec 04, 2018
by
Max Ehrlich
Browse files
Allow exploded convolutions to be precomputed for faster testing
parent
42ad5fb3
Changes
4
Hide whitespace changes
Inline
Side-by-side
experiments/model_conversion.py
View file @
b19163d7
...
...
@@ -33,6 +33,7 @@ for m in range(args.models):
print
(
'Convert model to JPEG'
)
jpeg_model
=
models
.
JpegResNetExact
(
spatial_model
).
to
(
device
)
jpeg_model
.
explode_all
()
print
(
'Test JPEG model'
)
jpeg_dataset
=
data
.
jpeg_dataset_map
[
args
.
dataset
](
args
.
batch_size
,
args
.
data
)
...
...
jpeg_layers/convolution.py
View file @
b19163d7
...
...
@@ -3,9 +3,9 @@ import opt_einsum as oe
import
numpy
as
np
class
Conv2d
Base
(
torch
.
nn
.
modules
.
Module
):
class
Conv2d
(
torch
.
nn
.
modules
.
Module
):
def
__init__
(
self
,
conv_spatial
,
J
):
super
(
Conv2d
Base
,
self
).
__init__
()
super
(
Conv2d
,
self
).
__init__
()
self
.
stride
=
conv_spatial
.
stride
self
.
weight
=
conv_spatial
.
weight
...
...
@@ -19,6 +19,8 @@ class Conv2dBase(torch.nn.modules.Module):
self
.
make_apply_op
()
self
.
jpeg_op
=
None
def
make_apply_op
(
self
):
input_shape
=
[
0
,
self
.
weight
.
shape
[
1
],
*
self
.
J_i
.
shape
[
0
:
3
]]
jpeg_op_shape
=
[
self
.
weight
.
shape
[
0
],
self
.
weight
.
shape
[
1
],
*
self
.
J_i
.
shape
[
0
:
3
],
*
self
.
J
.
shape
[
0
:
2
]]
...
...
@@ -27,7 +29,7 @@ class Conv2dBase(torch.nn.modules.Module):
self
.
apply_conv
.
evaluate_constants
(
backend
=
'torch'
)
def
_apply
(
self
,
fn
):
s
=
super
(
Conv2d
Base
,
self
).
_apply
(
fn
)
s
=
super
(
Conv2d
,
self
).
_apply
(
fn
)
s
.
make_apply_op
()
return
s
...
...
@@ -41,20 +43,13 @@ class Conv2dBase(torch.nn.modules.Module):
return
jpeg_op
class
Conv2dRT
(
Conv2dBase
):
def
__init__
(
self
,
conv_spatial
,
J
):
super
(
Conv2dRT
,
self
).
__init__
(
conv_spatial
,
J
)
def
forward
(
self
,
input
):
jpeg_op
=
self
.
explode
()
return
self
.
apply_conv
(
jpeg_op
,
input
,
backend
=
'torch'
)
class
Conv2dPre
(
Conv2dBase
):
def
__init__
(
self
,
conv_spatial
,
J
):
super
(
Conv2dPre
,
self
).
__init__
(
conv_spatial
,
J
)
def
explode_pre
(
self
):
self
.
jpeg_op
=
self
.
explode
()
def
forward
(
self
,
input
):
return
self
.
apply_conv
(
self
.
jpeg_op
,
input
,
backend
=
'torch'
)
if
self
.
jpeg_op
is
not
None
:
jpeg_op
=
self
.
jpeg_op
else
:
jpeg_op
=
self
.
explode
()
return
self
.
apply_conv
(
jpeg_op
,
input
,
backend
=
'torch'
)
models/blocks.py
View file @
b19163d7
...
...
@@ -48,8 +48,8 @@ class JpegResBlock(nn.Module):
J_down
=
(
J_out
[
0
],
J_in
[
1
])
self
.
conv1
=
jpeg_layers
.
Conv2d
RT
(
spatial_resblock
.
conv1
,
J_down
)
self
.
conv2
=
jpeg_layers
.
Conv2d
RT
(
spatial_resblock
.
conv2
,
J_out
)
self
.
conv1
=
jpeg_layers
.
Conv2d
(
spatial_resblock
.
conv1
,
J_down
)
self
.
conv2
=
jpeg_layers
.
Conv2d
(
spatial_resblock
.
conv2
,
J_out
)
self
.
bn1
=
jpeg_layers
.
BatchNorm
(
spatial_resblock
.
bn1
)
self
.
bn2
=
jpeg_layers
.
BatchNorm
(
spatial_resblock
.
bn2
)
...
...
@@ -57,10 +57,17 @@ class JpegResBlock(nn.Module):
self
.
relu
=
jpeg_layers
.
ReLU
(
n_freqs
=
n_freqs
)
if
spatial_resblock
.
downsampler
is
not
None
:
self
.
downsampler
=
jpeg_layers
.
Conv2d
RT
(
spatial_resblock
.
downsampler
,
J_down
)
self
.
downsampler
=
jpeg_layers
.
Conv2d
(
spatial_resblock
.
downsampler
,
J_down
)
else
:
self
.
downsampler
=
None
def
explode_all
(
self
):
self
.
conv1
.
explode_pre
()
self
.
conv2
.
explode_pre
()
if
self
.
downsampler
is
not
None
:
self
.
downsampler
.
explode_pre
()
def
forward
(
self
,
x
):
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
...
...
models/jpeg_model.py
View file @
b19163d7
...
...
@@ -19,6 +19,11 @@ class JpegResNet(nn.Module):
self
.
averagepooling
=
AvgPool
()
self
.
fc
=
spatial_model
.
fc
def
explode_all
(
self
):
self
.
block1
.
explode_all
()
self
.
block2
.
explode_all
()
self
.
block3
.
explode_all
()
def
forward
(
self
,
x
):
out
=
self
.
block1
(
x
)
out
=
self
.
block2
(
out
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment