Verified Commit df707d83 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Add relu layer using direct approximation

parent 46fb2b06
import torch
import opt_einsum as oe
from jpeg_codec import D_n, D, Z, S_i, S
from jpeg_codec import encode
class ReLU(torch.nn.modules.Module):
class ASMReLU(torch.nn.modules.Module):
def __init__(self, n_freqs):
super(ReLU, self).__init__()
super(ASMReLU, self).__init__()
C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
self.register_buffer('C_n', C_n)
......@@ -22,7 +23,7 @@ class ReLU(torch.nn.modules.Module):
def _apply(self, fn):
s = super(ReLU, self)._apply(fn)
s = super(ASMReLU, self)._apply(fn)
return s
......@@ -39,3 +40,37 @@ class ReLU(torch.nn.modules.Module):
annm = self.annm(input)
out_comp = self.half_spatial_mask(input, annm)
return out_comp
class APXReLU(torch.nn.modules.Module):
def __init__(self, n_freqs):
super(APXReLU, self).__init__()
C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
self.register_buffer('C_n', C_n)
C = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S()])
self.register_buffer('C', C)
def make_masking_ops(self):
self.annm_op = oe.contract_expression('ijk,tmxyk->tmxyij', self.C_n, [0, 0, 0, 0, 64], constants=[0], optimize='optimal')
self.compress_op = oe.contract_expression('ijk,tmxyij->tmxyk', self.C, [0, 0, 0, 0, 8, 8], constants=[0], optimize='optimal')
def _apply(self, fn):
s = super(APXReLU, self)._apply(fn)
return s
def appx_relu(self, x):
appx_im = self.annm_op(x, backend='torch')
appx_im[appx_im < 0] = 0
return appx_im
def forward(self, input):
appx = self.appx_relu(input)
appx = self.compress_op(appx, backend='torch')
return appx
......@@ -45,7 +45,7 @@ class SpatialResBlock(nn.Module):
class JpegResBlock(nn.Module):
def __init__(self, spatial_resblock, n_freqs, J_in, J_out):
def __init__(self, spatial_resblock, n_freqs, J_in, J_out, relu_layer=jpeg_layers.ASMReLU):
super(JpegResBlock, self).__init__()
J_down = (J_out[0], J_in[1])
......@@ -56,7 +56,7 @@ class JpegResBlock(nn.Module):
self.bn1 = jpeg_layers.BatchNorm(spatial_resblock.bn1)
self.bn2 = jpeg_layers.BatchNorm(spatial_resblock.bn2)
self.relu = jpeg_layers.ReLU(n_freqs=n_freqs)
self.relu = relu_layer(n_freqs=n_freqs)
if spatial_resblock.downsampler is not None:
self.downsampler = jpeg_layers.Conv2d(spatial_resblock.downsampler, J_down)
import torch.nn as nn
from .blocks import JpegResBlock
from jpeg_layers import AvgPool
from jpeg_layers import AvgPool, ASMReLU, APXReLU
from jpeg_codec import codec
class JpegResNet(nn.Module):
def __init__(self, spatial_model, n_freqs):
def __init__(self, spatial_model, n_freqs, relu_layer=ASMReLU):
super(JpegResNet, self).__init__()
J_32 = codec((32, 32))
J_16 = codec((16, 16))
J_8 = codec((8, 8))
self.block1 = JpegResBlock(spatial_model.block1, n_freqs=n_freqs, J_in=J_32, J_out=J_32)
self.block2 = JpegResBlock(spatial_model.block2, n_freqs=n_freqs, J_in=J_32, J_out=J_16)
self.block3 = JpegResBlock(spatial_model.block3, n_freqs=n_freqs, J_in=J_16, J_out=J_8)
self.block1 = JpegResBlock(spatial_model.block1, n_freqs=n_freqs, J_in=J_32, J_out=J_32, relu_layer=relu_layer)
self.block2 = JpegResBlock(spatial_model.block2, n_freqs=n_freqs, J_in=J_32, J_out=J_16, relu_layer=relu_layer)
self.block3 = JpegResBlock(spatial_model.block3, n_freqs=n_freqs, J_in=J_16, J_out=J_8, relu_layer=relu_layer)
self.averagepooling = AvgPool()
self.fc = spatial_model.fc
......@@ -40,3 +40,8 @@ class JpegResNet(nn.Module):
class JpegResNetExact(JpegResNet):
def __init__(self, spatial_model):
super(JpegResNetExact, self).__init__(spatial_model, 14)
class JpegResNetApx(JpegResNet):
def __init__(self, spatial_model, n_freqs):
super(JpegResNetApx, self).__init__(spatial_model, n_freqs, APXReLU)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment