✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。
我是Srlua小谢,在这里我会分享我的知识和经验。🎥
希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮
记得先点赞👍后阅读哦~ 👏👏
📘📚 所属专栏:传知代码论文复现
欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙
目录
本文所有资源均可在该地址处获取。
本方法选自CVPR2022"The Devil Is in the Details: Window-based Attention for Image Compression",是一种性能极高的图像压缩方法,源码可参考附件文件,同时本文会详细介绍复现过程
背景
随着视觉应用的日益增多,图像压缩已经成为图像处理领域的一个重要研究课题。传统的图像压缩方法,如JPEG、JPEG2000等,主要依赖于手工设计的规则,虽然这些方法在一定程度上解决了图像存储和传输的问题,但在处理复杂纹理和细节方面存在局限性。近年来,基于卷积神经网络(CNN)的学习型图像压缩方法展示了优越的率失真性能。然而,CNN在捕捉局部冗余和非重复纹理方面仍存在不足,这限制了图像重建质量的进一步提升。
本文通过结合局部注意力机制和全局特征学习,提出了一种新的图像压缩方法,名为“Symmetrical TransFormer (STF)”框架,并证明了其在压缩图像时的优越性能。
相关工作
在图像压缩领域,学习型图像压缩方法近年来发展迅速,基于变分自编码器(VAE)的模型在率失真性能方面优于传统的有损压缩方法。VAE通过使用线性和非线性参数分析变换将图像映射到潜在编码空间,并结合熵估计模块精确预测潜在变量的分布,从而提高了压缩效率。自回归先验和高斯混合模型(GMM)进一步增强了熵估计模块的性能,但同时也增加了计算复杂性和时间开销。
注意力机制通过模拟生物观察的内部过程,分配更多的注意力资源到关键区域,从而获得更多细节并抑制其他无用信息。在图像压缩中,非局部注意力机制已经被证明可以通过生成隐式重要性掩码来引导潜在特征的自适应处理。尽管如此,非局部注意力机制并未改变CNN结构的全局感知特性,而Transformer架构通过利用注意力机制捕捉全局依赖关系,在图像分类和语义分割等任务中取得了优异的性能。
提出的方法
窗口注意力机制
大多数先前的方法使用基于全局感受野的注意力机制生成注意力掩码。然而,在图像压缩任务中,全局语义信息的作用不如局部空间邻近元素的相关性大。为此,本文提出了窗口注意力机制,具体方法如下:
- 局部窗口中的注意力计算:为了有效地建模并关注空间邻近元素,我们将特征图划分为非重叠的M×M窗口,在每个窗口内分别计算注意力掩码。
- 窗口注意力模块:WAM将非局部注意力模块替换为窗口注意力模块,以更好地关注高对比度区域。通过可视化注意力模块的效果,可以看到WAM在复杂区域(高对比度)分配了更多比特,而在简单区域(低对比度)分配了较少比特,从而在这些区域内保留更多的细节。
图1 窗口注意力的可视化结果
基于CNN的架构
在现有的CNN架构中插入窗口注意力模块,能够更合理地分配比特。该模块作为一个即插即用的组件,能够显著提升现有模型的率失真性能,同时计算开销可以忽略不计。具体架构如下图所示:
图2 基于CNN的架构
- 编码器与解码器:分别在编码器和解码器中插入WAM模块,帮助在不同区域内部合理分配比特。
- 卷积层与残差块:在编码器和解码器中使用卷积层和残差块(Residual Block, RB)以提高特征提取和重建能力。
基于Transformer的架构
受到Transformer架构在计算机视觉任务中成功的启发,本文设计了一个对称的Transformer框架,具体步骤如下:
重新思考Transformer的设计:由于目标是验证自注意力层和MLP是否能在学习型图像压缩任务中达到与原始CNN架构相当的性能,本文设计了一个全新的对称Transformer框架,在编码器和解码器中均不使用卷积层。主要难点在于:
- 传统工作大多基于CNN消除空间冗余并捕捉空间结构,直接将图像划分为patch可能导致patch内部的空间冗余。
- GDN是图像压缩中最常用的归一化和非线性激活函数,但在深层Transformer架构中不稳定,且与注意力机制不兼容。
- 计算大视野上的注意力图并非最佳。
-
编码器设计:将原始图像x划分为N大小的patch,通过线性嵌入层生成特征图fp。然后,将特征图重塑为序列fs,输入到Transformer块和patch合并层中,逐渐降低特征的分辨率并加倍特征通道数。
-
解码器设计:设计对称的解码器,由多个Transformer块、patch分割层和反嵌入层组成。patch分割层逐渐增加特征的分辨率并减半特征通道数,反嵌入层将特征图映射为重建图像x̂。
-
熵模型:为了更有效地预测潜在变量的概率分布,采用基于单高斯模型(SGM)的通道自回归熵模型。
图2 基于Transformer的架构
实验
实验设置
作者在CompresAI平台上实现了提出的CNN和Transformer架构,并在OpenImages数据集中随机选择30万张图像进行训练,图像随机裁剪为256×256大小。所有模型使用Adam优化器训练1.8M步,批次大小为16,初始学习率为1×10−4。
性能评估
在Kodak和CLIC专业验证集上评估模型性能。结果表明,提出的方法在PSNR和MS-SSIM指标上优于现有的最先进方法。具体而言,CNN和STF模型在优化MS-SSIM时在视觉质量方面表现显著提升。
视觉质量比较
在Kodak数据集中对重建图像进行视觉质量比较,结果显示,提出的方法在高对比度区域保留了更多细节。相比之下,传统方法在这些区域的细节表现较差。
核心代码解读
- 窗口注意力机制的实现类
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
""" Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
- STF模型的前向传递过程
def forward(self, x):
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
for i in range(self.num_layers):
layer = self.layers[i]
x, Wh, Ww = layer(x, Wh, Ww)
y = x
C = self.embed_dim * 8
y = y.view(-1, Wh, Ww, C).permute(0, 3, 1, 2).contiguous()
y_shape = y.shape[2:]
z = self.h_a(y)
_, z_likelihoods = self.entropy_bottleneck(z)
z_offset = self.entropy_bottleneck._get_medians()
z_tmp = z - z_offset
z_hat = ste_round(z_tmp) + z_offset
latent_scales = self.h_scale_s(z_hat)
latent_means = self.h_mean_s(z_hat)
y_slices = y.chunk(self.num_slices, 1)
y_hat_slices = []
y_likelihood = []
for slice_index, y_slice in enumerate(y_slices):
support_slices = (y_hat_slices if self.max_support_slices < 0 else y_hat_slices[:self.max_support_slices])
mean_support = torch.cat([latent_means] + support_slices, dim=1)
mu = self.cc_mean_transforms[slice_index](mean_support)
mu = mu[:, :, :y_shape[0], :y_shape[1]]
scale_support = torch.cat([latent_scales] + support_slices, dim=1)
scale = self.cc_scale_transforms[slice_index](scale_support)
scale = scale[:, :, :y_shape[0], :y_shape[1]]
_, y_slice_likelihood = self.gaussian_conditional(y_slice, scale, mu)
y_likelihood.append(y_slice_likelihood)
y_hat_slice = ste_round(y_slice - mu) + mu
lrp_support = torch.cat([mean_support, y_hat_slice], dim=1)
lrp = self.lrp_transforms[slice_index](lrp_support)
lrp = 0.5 * torch.tanh(lrp)
y_hat_slice += lrp
y_hat_slices.append(y_hat_slice)
y_hat = torch.cat(y_hat_slices, dim=1)
y_likelihoods = torch.cat(y_likelihood, dim=1)
y_hat = y_hat.permute(0, 2, 3, 1).contiguous().view(-1, Wh*Ww, C)
for i in range(self.num_layers):
layer = self.syn_layers[i]
y_hat, Wh, Ww = layer(y_hat, Wh, Ww)
x_hat = self.end_conv(y_hat.view(-1, Wh, Ww, self.embed_dim).permute(0, 3, 1, 2).contiguous())
return {
"x_hat": x_hat,
"likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
}
- 使用的损失函数
class RateDistortionLoss(nn.Module):
"""Custom rate distortion loss with a Lagrangian parameter."""
def __init__(self, lmbda=1e-2):
super().__init__()
self.mse = nn.MSELoss()
self.lmbda = lmbda
def forward(self, output, target):
N, _, H, W = target.size()
out = {}
num_pixels = N * H * W
out["bpp_loss"] = sum(
(torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
for likelihoods in output["likelihoods"].values()
)
out["mse_loss"] = self.mse(output["x_hat"], target)
out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"]
return out
代码部署和复现流程
- 安装CompressAI
conda create -n compress python=3.7
conda activate compress
pip install compressai
pip install pybind11
git clone https://github.com/Googolxx/STF stf
cd stf
pip install -e .
pip install -e '.[dev]'
- 训练CNN模型
CUDA_VISIBLE_DEVICES=0,1 python train.py -d /path/to/image/dataset/ -e 1000 --batch-size 16 --save --save_path /path/to/save/ -m cnn --cuda --lambda 0.0035
e.g., CUDA_VISIBLE_DEVICES=0,1 python train.py -d openimages -e 1000 --batch-size 16 --save --save_path ckpt/cnn_0035.pth.tar -m cnn --cuda --lambda 0.0035
- 训练Transformer模型
CUDA_VISIBLE_DEVICES=0,1 python train.py -d /path/to/image/dataset/ -e 1000 --batch-size 16 --save --save_path /path/to/save/ -m stf --cuda --lambda 0.0035
- 评估模型
CUDA_VISIBLE_DEVICES=0 python -m compressai.utils.eval_model -d /path/to/image/folder/ -r /path/to/reconstruction/folder/ -a stf -p /path/to/checkpoint/ --cuda
CUDA_VISIBLE_DEVICES=0 python -m compressai.utils.eval_model -d /path/to/image/folder/ -r /path/to/reconstruction/folder/ -a cnn -p /path/to/checkpoint/ --cuda