Verified Commit df707d83 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Add relu layer using direct approximation

parent 46fb2b06
import torch import torch
import opt_einsum as oe import opt_einsum as oe
from jpeg_codec import D_n, D, Z, S_i, S from jpeg_codec import D_n, D, Z, S_i, S
from jpeg_codec import encode
class ReLU(torch.nn.modules.Module): class ASMReLU(torch.nn.modules.Module):
def __init__(self, n_freqs): def __init__(self, n_freqs):
super(ReLU, self).__init__() super(ASMReLU, self).__init__()
C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()]) C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
self.register_buffer('C_n', C_n) self.register_buffer('C_n', C_n)
...@@ -22,7 +23,7 @@ class ReLU(torch.nn.modules.Module): ...@@ -22,7 +23,7 @@ class ReLU(torch.nn.modules.Module):
self.hsm_op.evaluate_constants(backend='torch') self.hsm_op.evaluate_constants(backend='torch')
def _apply(self, fn): def _apply(self, fn):
s = super(ReLU, self)._apply(fn) s = super(ASMReLU, self)._apply(fn)
s.make_masking_ops() s.make_masking_ops()
return s return s
...@@ -39,3 +40,37 @@ class ReLU(torch.nn.modules.Module): ...@@ -39,3 +40,37 @@ class ReLU(torch.nn.modules.Module):
annm = self.annm(input) annm = self.annm(input)
out_comp = self.half_spatial_mask(input, annm) out_comp = self.half_spatial_mask(input, annm)
return out_comp return out_comp
class APXReLU(torch.nn.modules.Module):
def __init__(self, n_freqs):
super(APXReLU, self).__init__()
C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
self.register_buffer('C_n', C_n)
C = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S()])
self.register_buffer('C', C)
self.make_masking_ops()
def make_masking_ops(self):
self.annm_op = oe.contract_expression('ijk,tmxyk->tmxyij', self.C_n, [0, 0, 0, 0, 64], constants=[0], optimize='optimal')
self.annm_op.evaluate_constants(backend='torch')
self.compress_op = oe.contract_expression('ijk,tmxyij->tmxyk', self.C, [0, 0, 0, 0, 8, 8], constants=[0], optimize='optimal')
self.compress_op.evaluate_constants(backend='torch')
def _apply(self, fn):
s = super(APXReLU, self)._apply(fn)
s.make_masking_ops()
return s
def appx_relu(self, x):
appx_im = self.annm_op(x, backend='torch')
appx_im[appx_im < 0] = 0
return appx_im
def forward(self, input):
appx = self.appx_relu(input)
appx = self.compress_op(appx, backend='torch')
return appx
...@@ -45,7 +45,7 @@ class SpatialResBlock(nn.Module): ...@@ -45,7 +45,7 @@ class SpatialResBlock(nn.Module):
class JpegResBlock(nn.Module): class JpegResBlock(nn.Module):
def __init__(self, spatial_resblock, n_freqs, J_in, J_out): def __init__(self, spatial_resblock, n_freqs, J_in, J_out, relu_layer=jpeg_layers.ASMReLU):
super(JpegResBlock, self).__init__() super(JpegResBlock, self).__init__()
J_down = (J_out[0], J_in[1]) J_down = (J_out[0], J_in[1])
...@@ -56,7 +56,7 @@ class JpegResBlock(nn.Module): ...@@ -56,7 +56,7 @@ class JpegResBlock(nn.Module):
self.bn1 = jpeg_layers.BatchNorm(spatial_resblock.bn1) self.bn1 = jpeg_layers.BatchNorm(spatial_resblock.bn1)
self.bn2 = jpeg_layers.BatchNorm(spatial_resblock.bn2) self.bn2 = jpeg_layers.BatchNorm(spatial_resblock.bn2)
self.relu = jpeg_layers.ReLU(n_freqs=n_freqs) self.relu = relu_layer(n_freqs=n_freqs)
if spatial_resblock.downsampler is not None: if spatial_resblock.downsampler is not None:
self.downsampler = jpeg_layers.Conv2d(spatial_resblock.downsampler, J_down) self.downsampler = jpeg_layers.Conv2d(spatial_resblock.downsampler, J_down)
......
import torch.nn as nn import torch.nn as nn
from .blocks import JpegResBlock from .blocks import JpegResBlock
from jpeg_layers import AvgPool from jpeg_layers import AvgPool, ASMReLU, APXReLU
from jpeg_codec import codec from jpeg_codec import codec
class JpegResNet(nn.Module): class JpegResNet(nn.Module):
def __init__(self, spatial_model, n_freqs): def __init__(self, spatial_model, n_freqs, relu_layer=ASMReLU):
super(JpegResNet, self).__init__() super(JpegResNet, self).__init__()
J_32 = codec((32, 32)) J_32 = codec((32, 32))
J_16 = codec((16, 16)) J_16 = codec((16, 16))
J_8 = codec((8, 8)) J_8 = codec((8, 8))
self.block1 = JpegResBlock(spatial_model.block1, n_freqs=n_freqs, J_in=J_32, J_out=J_32) self.block1 = JpegResBlock(spatial_model.block1, n_freqs=n_freqs, J_in=J_32, J_out=J_32, relu_layer=relu_layer)
self.block2 = JpegResBlock(spatial_model.block2, n_freqs=n_freqs, J_in=J_32, J_out=J_16) self.block2 = JpegResBlock(spatial_model.block2, n_freqs=n_freqs, J_in=J_32, J_out=J_16, relu_layer=relu_layer)
self.block3 = JpegResBlock(spatial_model.block3, n_freqs=n_freqs, J_in=J_16, J_out=J_8) self.block3 = JpegResBlock(spatial_model.block3, n_freqs=n_freqs, J_in=J_16, J_out=J_8, relu_layer=relu_layer)
self.averagepooling = AvgPool() self.averagepooling = AvgPool()
self.fc = spatial_model.fc self.fc = spatial_model.fc
...@@ -40,3 +40,8 @@ class JpegResNet(nn.Module): ...@@ -40,3 +40,8 @@ class JpegResNet(nn.Module):
class JpegResNetExact(JpegResNet): class JpegResNetExact(JpegResNet):
def __init__(self, spatial_model): def __init__(self, spatial_model):
super(JpegResNetExact, self).__init__(spatial_model, 14) super(JpegResNetExact, self).__init__(spatial_model, 14)
class JpegResNetApx(JpegResNet):
def __init__(self, spatial_model, n_freqs):
super(JpegResNetApx, self).__init__(spatial_model, n_freqs, APXReLU)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment