Verified Commit 6834e9aa authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Add initial layers

parent 702d6e34
import torch
class JpegAvgPool(torch.nn.modules.Module):
    """Global average pooling over JPEG-domain inputs.

    Expects a 5-D tensor (batch, channels, blocks_h, blocks_w, coeffs) and
    averages coefficient 0 of every block over all block positions, yielding
    a (batch, channels) tensor.
    """

    def __init__(self):
        super(JpegAvgPool, self).__init__()

    def forward(self, input):
        # Keep only coefficient 0 of each block (the first entry on the last axis).
        dc = input[:, :, :, :, 0]
        # Merge the two block-grid axes so a single mean covers every block.
        flattened = dc.view(-1, input.shape[1], input.shape[2] * input.shape[3])
        result = torch.mean(flattened, 2)
        return result
import torch
class JpegBatchNorm(torch.nn.modules.Module):
    """Fold a trained BatchNorm2d into a JPEG-domain scale-and-shift.

    Eval-mode batch norm is the affine map
        y = gamma * (x - mean) / sqrt(var + eps) + beta.
    The multiplicative part is linear, so it applies to every coefficient;
    the additive part only shifts the index-0 coefficient of each block
    (the DC term in JPEG layout — the other coefficients are zero-mean).

    Parameters
    ----------
    bn : torch.nn.BatchNorm2d
        Trained layer whose running statistics and affine parameters are
        folded into constants at construction time.
    """

    def __init__(self, bn):
        super(JpegBatchNorm, self).__init__()
        self.mean = bn.running_mean
        self.var = bn.running_var
        self.gamma = bn.weight
        self.beta = bn.bias
        # BUGFIX: include bn.eps under the square root. PyTorch normalizes
        # with sqrt(var + eps) in eval mode; omitting eps makes this layer
        # diverge from the original bn, badly so for near-zero variances.
        denom = torch.sqrt(self.var + bn.eps)
        # Per-channel scale, broadcast over the full 5-D input.
        self.gamma_final = (self.gamma / denom).view(1, self.gamma.shape[0], 1, 1, 1)
        # Per-channel shift, broadcast over the 4-D index-0 slice.
        self.beta_final = (self.beta - (self.gamma * self.mean) / denom).view(1, self.beta.shape[0], 1, 1)

    def forward(self, input):
        # Scale every coefficient ...
        input = input * self.gamma_final
        # ... but shift only coefficient 0 of each block.
        input[:, :, :, :, 0] = input[:, :, :, :, 0] + self.beta_final
        return input
import torch
import opt_einsum as oe
import numpy as np
class Conv2d_base(torch.nn.modules.Module):
    """Base class for applying a spatial Conv2d directly to JPEG-domain inputs.

    Wraps an existing ``torch.nn.Conv2d`` (``conv_spatial``) and a pair of
    transform tensors ``J``.

    NOTE(review): ``J`` is opaque here — presumably the JPEG (DCT)
    decode/encode operators; confirm against the code that builds it. The
    indexing below requires ``J[1]`` to be 5-D and ``J[0]`` to match the
    'srabc' term of the contraction (5 indices) — TODO confirm.
    """

    def __init__(self, conv_spatial, J):
        super(Conv2d_base, self).__init__()
        # Reuse the spatial convolution's parameters unchanged; ``weight``
        # stays shared with conv_spatial (no copy).
        self.stride = conv_spatial.stride
        self.weight = conv_spatial.weight
        self.padding = conv_spatial.padding
        self.J = J
        # Flatten the first three dims of J[1] into a batch of single-channel
        # "images" so explode() can push them all through one conv2d call.
        self.J_batched = self.J[1].contiguous().view(np.prod(self.J[1].shape[0:3]), 1, *self.J[1].shape[3:5])
        # Symbolic shapes handed to opt_einsum; the leading 0 marks the
        # unknown (batch) dimension.
        input_shape = [0, self.weight.shape[1], *self.J[1].shape[0:3]]
        jpeg_op_shape = [self.weight.shape[0], self.weight.shape[1], *self.J[1].shape[0:3], *self.J[0].shape[0:2]]
        # Pre-optimized contraction applying the exploded JPEG-domain operator
        # ('mnxyksr') to an input batch ('tnxyk'); J[0] is baked in as a
        # constant operand (constants=[1]) and pre-evaluated on torch.
        self.apply_conv = oe.contract_expression('mnxyksr,srabc,tnxyk->tmabc', jpeg_op_shape, self.J[0], input_shape, constants=[1], optimize='optimal')
        self.apply_conv.evaluate_constants(backend='torch')

    def explode(self):
        """Build the JPEG-domain weight tensor consumed by ``apply_conv``.

        Reshapes each (out, in) filter into its own single-channel filter so
        one conv2d call convolves every filter with every J[1] slice, then
        restores the (out, in, x, y, k, h, w) layout.
        """
        out_channels = self.weight.shape[0]
        in_channels = self.weight.shape[1]
        jpeg_op = torch.nn.functional.conv2d(self.J_batched, self.weight.view(out_channels * in_channels, 1, self.weight.shape[2], self.weight.shape[3]), padding=self.padding, stride=self.stride)
        # (batch, filters, h, w) -> (filters, batch, h, w) before re-viewing.
        jpeg_op = jpeg_op.permute(1, 0, 2, 3)
        # Spatial dims shrink by the stride; integer // assumes padding makes
        # them divide evenly — TODO confirm for strided configs.
        jpeg_op = jpeg_op.view(out_channels, in_channels, *self.J[1].shape[0:3], *(np.array(self.J[1].shape[3:5]) // self.stride))
        return jpeg_op
class Conv2d_a(Conv2d_base):
    """JPEG-domain convolution that re-derives the exploded operator per call.

    Recomputing ``explode()`` on every forward pass means the output always
    reflects the current value of ``self.weight``.
    """

    def __init__(self, conv_spatial, J):
        super(Conv2d_a, self).__init__(conv_spatial, J)

    def forward(self, input):
        # Rebuild the JPEG-domain weights from the (possibly updated) spatial
        # weights, then apply the pre-optimized contraction.
        exploded = self.explode()
        return self.apply_conv(exploded, input, backend='torch')
class Conv2d_b(Conv2d_base):
    """JPEG-domain convolution with the exploded operator cached at init.

    ``explode()`` runs once in the constructor, so forward passes skip the
    rebuild; the cached operator does not track later weight changes.
    """

    def __init__(self, conv_spatial, J):
        super(Conv2d_b, self).__init__(conv_spatial, J)
        # Precompute once; reused verbatim on every forward call.
        self.jpeg_op = self.explode()

    def forward(self, input):
        return self.apply_conv(self.jpeg_op, input, backend='torch')
import torch
import opt_einsum as oe
class JpegRelu(torch.nn.modules.Module):
    """Approximate ReLU applied to JPEG-domain inputs.

    Builds an approximate spatial reconstruction from the first ``n_freqs``
    coefficients, derives a nonnegativity mask from it, and applies that mask
    back in the coefficient domain.

    NOTE(review): depends on ``D_n``, ``D``, ``Z``, ``S`` and ``S_i``, which
    are neither defined nor imported in this file — presumably DCT-basis /
    zigzag / scaling helpers from a sibling module; confirm their origin.
    The fixed sizes 64 and 8x8 below suggest standard JPEG 8x8 blocks.
    """

    def __init__(self, n_freqs):
        super(JpegRelu, self).__init__()
        # C_n: maps the 64 block coefficients ('k') to an 8x8 spatial
        # approximation ('ij') using only the first n_freqs frequencies.
        self.C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
        # Hm: combines input coefficients ('k') with an 8x8 spatial mask
        # ('ij') and returns masked coefficients ('l').
        self.Hm = torch.einsum('ijab,ijuv,abg,gk,uvh,hl->ijkl', [D(), D(), Z(), S_i(), Z(), S()])
        # Pre-optimized contractions; the 0s mark unknown (batch-like) dims
        # and the built tensors are baked in as constants (constants=[0]).
        self.annm_op = oe.contract_expression('ijk,tmxyk->tmxyij', self.C_n, [0, 0, 0, 0, 64], constants=[0], optimize='optimal')
        self.annm_op.evaluate_constants(backend='torch')
        self.hsm_op = oe.contract_expression('ijkl,tmxyk,tmxyij->tmxyl', self.Hm, [0, 0, 0, 0, 64], [0, 0, 0, 0, 8, 8], constants=[0], optimize='optimal')
        self.hsm_op.evaluate_constants(backend='torch')

    def annm(self, x):
        """Approximate nonnegativity mask: 1 where the approximate spatial
        reconstruction of ``x`` is >= 0, else 0."""
        appx_im = self.annm_op(x, backend='torch')
        mask = torch.zeros_like(appx_im)
        mask[appx_im >= 0] = 1
        return mask

    def half_spatial_mask(self, x, m):
        """Apply spatial mask ``m`` to coefficients ``x`` via the Hm operator."""
        return self.hsm_op(x, m, backend='torch')

    def forward(self, input):
        annm = self.annm(input)
        out_comp = self.half_spatial_mask(input, annm)
        return out_comp
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment