...
 
Commits (4)
from .codec import *
from .tensors import *
\ No newline at end of file
......@@ -6,7 +6,7 @@ C = torch.einsum('ijab,abg,gk->ijk', (D(), Z(), S()))
C_i = torch.einsum('ijab,abg,gk->ijk', (D(), Z(), S_i()))
def codec(image_size, block_size):
def codec(image_size, block_size=(8, 8)):
B_i = B(image_size, block_size)
J = torch.einsum('srxyij,ijk->srxyk', (B_i, C))
J_i = torch.einsum('srxyij,ijk->xyksr', (B_i, C_i))
......
......@@ -102,7 +102,7 @@ def S_i():
def B(shape, block_size):
    """Build the pixel-to-block "unfolding" tensor.

    Returns a 0/1 tensor of shape
    (H, W, H // bh, W // bw, bh, bw) where entry
    [s_x, s_y, x, y, i, j] is 1.0 exactly when pixel (s_x, s_y) sits at
    offset (i, j) inside block (x, y), i.e.
    x * bh + i == s_x and y * bw + j == s_y.

    Args:
        shape: (H, W) spatial size; assumed divisible by block_size
            (the remainder pixels would otherwise be dropped — TODO confirm).
        block_size: (bh, bw) block dimensions.

    Each pixel belongs to exactly one (block, offset) pair, so we assign
    that single entry directly instead of scanning all six indices and
    testing the condition (the original was O(H*W * H*W) per pixel pair).
    """
    blocks_shape = (shape[0] // block_size[0], shape[1] // block_size[1])
    B_t = torch.zeros([shape[0], shape[1], blocks_shape[0], blocks_shape[1], block_size[0], block_size[1]], dtype=torch.float)
    for s_x in range(shape[0]):
        for s_y in range(shape[1]):
            # block index = pixel // block size, in-block offset = pixel % block size
            B_t[s_x, s_y, s_x // block_size[0], s_y // block_size[1], s_x % block_size[0], s_y % block_size[1]] = 1.0
    return B_t
from .avgpool import *
from .batchnorm import *
from .convolution import *
from .relu import *
\ No newline at end of file
import torch
class JpegAvgPool(torch.nn.modules.Module):
class AvgPool(torch.nn.modules.Module):
def __init__(self):
super(JpegAvgPool, self).__init__()
super(AvgPool, self).__init__()
def forward(self, input):
result = torch.mean(input[:, :, :, :, 0].view(-1, input.shape[1], input.shape[2]*input.shape[3]), 2)
......
import torch
class JpegBatchNorm(torch.nn.modules.Module):
class BatchNorm(torch.nn.modules.Module):
def __init__(self, bn):
super(JpegBatchNorm, self).__init__()
super(BatchNorm, self).__init__()
self.mean = bn.running_mean
self.var = bn.running_var
......
......@@ -3,9 +3,9 @@ import opt_einsum as oe
import numpy as np
class Conv2d_base(torch.nn.modules.Module):
class Conv2dBase(torch.nn.modules.Module):
def __init__(self, conv_spatial, J):
super(Conv2d_base, self).__init__()
super(Conv2dBase, self).__init__()
self.stride = conv_spatial.stride
self.weight = conv_spatial.weight
......@@ -31,18 +31,18 @@ class Conv2d_base(torch.nn.modules.Module):
return jpeg_op
class Conv2d(Conv2dBase):
    """JPEG-domain convolution that rebuilds the exploded operator on every call.

    NOTE(review): this span contained git-diff residue (old-name lines
    `Conv2d_a` interleaved with the renamed version); only the renamed
    new-version code is kept.
    """

    def __init__(self, conv_spatial, J):
        super(Conv2d, self).__init__(conv_spatial, J)

    def forward(self, input):
        # explode() lifts the spatial weights into the JPEG operator on
        # each forward pass (cf. Conv2dPre, which precomputes it once
        # in __init__).
        jpeg_op = self.explode()
        return self.apply_conv(jpeg_op, input, backend='torch')
class Conv2d_b(Conv2d_base):
class Conv2dPre(Conv2dBase):
def __init__(self, conv_spatial, J):
super(Conv2d_b, self).__init__(conv_spatial, J)
super(Conv2dPre, self).__init__(conv_spatial, J)
self.jpeg_op = self.explode()
def forward(self, input):
......
import torch
import opt_einsum as oe
from jpeg_codec.tensors import D_n, D, Z, S_i, S
from jpeg_codec import D_n, D, Z, S_i, S
class JpegRelu(torch.nn.modules.Module):
class ReLU(torch.nn.modules.Module):
def __init__(self, n_freqs):
super(JpegRelu, self).__init__()
super(ReLU, self).__init__()
self.C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
self.Hm = torch.einsum('ijab,ijuv,abg,gk,uvh,hl->ijkl', [D(), D(), Z(), S_i(), Z(), S()])
......
import torch.nn as nn
import jpeg_layers
class SpatialResBlock(nn.Module):
    """Residual block in the spatial (pixel) domain.

    Two 3x3 conv + batch-norm stages with ReLU, plus a shortcut
    connection.  When ``downsample`` is True the first conv uses
    stride 2 and the shortcut is a 1x1 stride-2 convolution; otherwise
    the shortcut is the identity.
    """

    def __init__(self, in_channels, out_channels, downsample=True):
        super(SpatialResBlock, self).__init__()
        self.downsample = downsample
        first_stride = 2 if downsample else 1
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=3, stride=first_stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        if downsample:
            # 1x1 stride-2 projection so the shortcut matches the main path.
            self.downsampler = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=1, stride=2, padding=0, bias=False)

    def forward(self, x):
        # Shortcut branch (projection when downsampling, identity otherwise).
        residual = self.downsampler(x) if self.downsample else x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return self.relu(out)
class JpegResBlock(nn.Module):
    """JPEG-domain counterpart of a spatial residual block.

    Wraps the layers of an already-trained ``spatial_resblock`` in their
    ``jpeg_layers`` equivalents.  ``J_in`` / ``J_out`` are the codec
    tensor pairs for the input and output resolutions; the first conv
    and the shortcut use the mixed pair (output forward, input inverse)
    to handle the resolution change.
    """

    def __init__(self, spatial_resblock, n_freqs, J_in, J_out):
        super(JpegResBlock, self).__init__()
        # Mixed codec pair for the (possibly strided) in->out convolutions.
        J_down = (J_out[0], J_in[1])

        self.conv1 = jpeg_layers.Conv2d(spatial_resblock.conv1, J_down)
        self.conv2 = jpeg_layers.Conv2d(spatial_resblock.conv2, J_out)
        self.bn1 = jpeg_layers.BatchNorm(spatial_resblock.bn1)
        self.bn2 = jpeg_layers.BatchNorm(spatial_resblock.bn2)
        self.relu = jpeg_layers.ReLU(n_freqs=n_freqs)

        self.downsample = spatial_resblock.downsample
        if self.downsample:
            self.downsampler = jpeg_layers.Conv2d(spatial_resblock.downsampler, J_down)

    def forward(self, x):
        # Shortcut branch mirrors the spatial block's structure.
        residual = self.downsampler(x) if self.downsample else x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return self.relu(out)
import torch.nn as nn
from .blocks import JpegResBlock
from jpeg_layers import AvgPool
from jpeg_codec import codec
class JpegResNet(nn.Module):
    """ResNet that runs directly on JPEG (block-DCT) representations.

    Built by wrapping the blocks of a trained ``spatial_model``;
    ``n_freqs`` selects how many DCT frequencies each ReLU approximation
    retains.  The fully-connected head is shared with the spatial model.
    """

    def __init__(self, spatial_model, n_freqs):
        super(JpegResNet, self).__init__()
        # Codec tensors for each feature-map resolution in the network.
        J = {size: codec((size, size)) for size in (32, 16, 8)}

        self.block1 = JpegResBlock(spatial_model.block1, n_freqs=n_freqs,
                                   J_in=J[32], J_out=J[32])
        self.block2 = JpegResBlock(spatial_model.block2, n_freqs=n_freqs,
                                   J_in=J[32], J_out=J[16])
        self.block3 = JpegResBlock(spatial_model.block3, n_freqs=n_freqs,
                                   J_in=J[16], J_out=J[8])
        self.averagepooling = AvgPool()
        self.fc = spatial_model.fc

    def forward(self, x):
        features = self.block3(self.block2(self.block1(x)))
        pooled = self.averagepooling(features)
        return self.fc(pooled.view(x.size(0), -1))
class JpegResNetExact(JpegResNet):
    """JpegResNet configured with all 14 frequencies (exact, no truncation —
    value carried over from the original code; TODO confirm intent)."""

    def __init__(self, spatial_model):
        # BUG FIX: super() must name THIS class, not the parent.
        # `super(JpegResNet, self)` starts the MRO lookup *after*
        # JpegResNet, skipping JpegResNet.__init__ entirely and passing
        # (spatial_model, 14) straight to nn.Module.__init__.
        super(JpegResNetExact, self).__init__(spatial_model, 14)
import torch.nn as nn
from .blocks import SpatialResBlock
class SpatialResNet(nn.Module):
def __init__(self):
super(SpatialResNet, self).__init__()
self.block1 = SpatialResBlock(in_channels=1, out_channels=16, downsample=False)
self.block2 = SpatialResBlock(in_channels=16, out_channels=32)
self.block3 = SpatialResBlock(in_channels=32, out_channels=64)
self.averagepooling = nn.AvgPool2d(8, stride=1)
self.fc = nn.Linear(64, 10)
def forward(self, x):
out = self.block1(x)
out = self.block2(out)
out = self.block3(out)
out = self.averagepooling(out)
out = out.view(x.size(0), -1)
out = self.fc(out)
return out