Verified Commit 702d6e34 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Add codec operations

parent 1916c15a
import torch
import numpy as np
from .tensors import D, Z, S, S_i, B
C = torch.einsum('ijab,abg,gk->ijk', (D(), Z(), S()))
C_i = torch.einsum('ijab,abg,gk->ijk', (D(), Z(), S_i()))
def codec(image_size, block_size):
B_i = B(image_size, block_size)
J = torch.einsum('srxyij,ijk->srxyk', (B_i, C))
J_i = torch.einsum('srxyij,ijk->xyksr', (B_i, C_i))
return J, J_i
def encode(batch, block_size=(8, 8), device=None):
J, _ = codec(batch.shape[2:], block_size)
if device is not None:
batch = batch.to(device)
J = J.to(device)
jpeg_batch = torch.einsum('srxyk,ncsr->ncxyk', (J, batch))
return jpeg_batch
def decode(batch, device=None):
block_size = np.sqrt(batch.shape[4])
image_size = (batch.shape[2] * block_size, batch.shape[3] * block_size)
_, J_i = codec(image_size, (block_size, block_size))
if device is not None:
batch = batch.to(device)
J_i = J_i.to(device)
spatial_batch = torch.einsum('xyksr,ncxyk->ncsr', (J_i, batch))
return spatial_batch
import torch
import numpy as np
def A(alpha):
if alpha == 0:
return 1.0 / np.sqrt(2)
else:
return 1
def D():
D_t = torch.zeros([8, 8, 8, 8], dtype=torch.float)
for i in range(8):
for j in range(8):
for alpha in range(8):
for beta in range(8):
scale_a = A(alpha)
scale_b = A(beta)
coeff_x = np.cos(((2 * i + 1) * alpha * np.pi) / 16)
coeff_y = np.cos(((2 * j + 1) * beta * np.pi) / 16)
D_t[i, j, alpha, beta] = 0.25 * scale_a * scale_b * coeff_x * coeff_y
return D_t
def D_n(n_freqs):
D_t = torch.zeros([8, 8, 8, 8], dtype=torch.float)
for i in range(8):
for j in range(8):
for alpha in range(8):
for beta in range(8):
if alpha + beta <= n_freqs:
scale_a = A(alpha)
scale_b = A(beta)
coeff_x = np.cos(((2 * i + 1) * alpha * np.pi) / 16)
coeff_y = np.cos(((2 * j + 1) * beta * np.pi) / 16)
D_t[i, j, alpha, beta] = 0.25 * scale_a * scale_b * coeff_x * coeff_y
return D_t
def Z():
z = np.array([[ 0, 1, 5, 6, 14, 15, 27, 28],
[ 2, 4, 7, 13, 16, 26, 29, 42],
[ 3, 8, 12, 17, 25, 30, 41, 43],
[ 9, 11, 18, 24, 31, 40, 44, 53],
[10, 19, 23, 32, 39, 45, 52, 54],
[20, 22, 33, 38, 46, 51, 55, 60],
[21, 34, 37, 47, 50, 56, 59, 61],
[35, 36, 48, 49, 57, 58, 62, 63]], dtype=float)
Z_t = torch.zeros([8, 8, 64], dtype=torch.float)
for alpha in range(8):
for beta in range(8):
for gamma in range(64):
if z[alpha, beta] == gamma:
Z_t[alpha, beta, gamma] = 1
return Z_t
def S():
q = np.array([ 8, 16, 16, 19, 16, 19, 22, 22, 22, 22, 22, 22, 26, 24, 26, 27,
27, 27, 26, 26, 26, 26, 27, 27, 27, 29, 29, 29, 34, 34, 34, 29,
29, 29, 27, 27, 29, 29, 32, 32, 34, 34, 37, 38, 37, 35, 35, 34,
35, 38, 38, 40, 40, 40, 48, 48, 46, 46, 56, 56, 58, 69, 69, 83], dtype=float)
S_t = torch.zeros([64, 64], dtype=torch.float)
for gamma in range(64):
for k in range(64):
if gamma == k:
S_t[gamma, k] = 1.0 / q[k]
return S_t
def S_i():
q = np.array([ 8, 16, 16, 19, 16, 19, 22, 22, 22, 22, 22, 22, 26, 24, 26, 27,
27, 27, 26, 26, 26, 26, 27, 27, 27, 29, 29, 29, 34, 34, 34, 29,
29, 29, 27, 27, 29, 29, 32, 32, 34, 34, 37, 38, 37, 35, 35, 34,
35, 38, 38, 40, 40, 40, 48, 48, 46, 46, 56, 56, 58, 69, 69, 83], dtype=float)
S_t = torch.zeros([64, 64], dtype=torch.float)
for gamma in range(64):
for k in range(64):
if gamma == k:
S_t[gamma, k] = q[k]
return S_t
def B(shape, block_size):
blocks_shape = (shape[0] // block_size[0], shape[1] // block_size[1])
B_t = torch.zeros([shape[0], shape[1], blocks_shape[0], blocks_shape[1], block_size[0], block_size[1]], device=device, dtype=torch.float)
for s_x in range(shape[0]):
for s_y in range(shape[1]):
for x in range(blocks_shape[0]):
for y in range(blocks_shape[1]):
for i in range(block_size[0]):
for j in range(block_size[1]):
if x * block_size[0] + i == s_x and y * block_size[1] + j == s_y:
B_t[s_x, s_y, x, y, i ,j] = 1.0
return B_t
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment