I decided to look into an extension of the regular transformer: the Vision Transformer. As the name suggests, this type of transformer takes an image as input rather than a sequence of words. This blog post gives an overview of the Vision Transformer architecture and implements a Vision Transformer based classifier on the CIFAR100 dataset.
Vision Transformer Architecture
As the name suggests, Vision Transformers are transformers applied to image data. They are typically encoder-based, which means no self-attention mask is needed: every query vector can use every key vector when producing attention weights. Unlike text sequences, images are not naturally suited to being fed into a transformer, so the main consideration is deciding how to turn the input image into tokens. We could use every pixel of the image as an input token, but the memory a transformer requires grows quadratically with the number of input tokens, and this quickly becomes infeasible as the spatial size of the image grows.
Instead, we split the input image into patches of size PxPxC. These patches are then flattened and stacked to form a matrix of size Nx(P*P*C), where N = (H*W)/P² is the number of patches for an image of size HxW.
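For example, with the patch settings used in the original Vision Transformer paper, a 224x224 RGB image split into 16x16 patches gives (224*224)/(16²) = 196 patches, each flattened to a vector of 16*16*3 = 768 values.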
Each patch is then fed into the transformer as a token, and a learned positional encoding is added to each token. Typically the images in a dataset are of a fixed size, so a Vision Transformer sees a fixed number of tokens, in contrast to text-based transformers.
The general architecture of the Vision Transformer is shown below.
One interesting aspect is the addition of a randomly initialized, learnable parameter called the class token, which is included as part of the input. As it passes through the transformer layers of the network, the class token can accumulate information from the other tokens in the sequence. It does this through the attention mechanism: when acting as a query, it can aggregate information from all of the patches via the attention weights.
Classification is then performed by feeding the class token from the final layer into a linear layer. If we instead concatenated all of the tokens from the final layer and fed them into a classification head, this would add a huge number of parameters to the network, which would be inefficient and impractical. Having a separate class token ensures that the Vision Transformer learns a general representation of the whole sequence into that token, and does not bias the final output towards any single token in the sequence.
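Schematically, prepending a class token and reading it back out for classification looks something like the sketch below. The names here are hypothetical; the actual implementation used in this post is shown later.

# Schematic sketch only; transformer_encoder and classification_head are
# hypothetical placeholders, not the post's implementation.
class_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))
tokens = torch.cat(
    (class_token.expand(batch_size, -1, -1), patch_tokens), dim=1
)  # (batch_size, num_patches + 1, embed_dim)
encoded = transformer_encoder(tokens)
logits = classification_head(encoded[:, 0])  # read out only the class token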
We can form a Vision Transformer based network by stacking Vision Transformer layers on top of each other and adding a classification head at the end, just as in a GPT-style model. Since the Vision Transformer acts as an encoder, we do not need to worry about any attention masks in the model.
Vision Transformer Implementation
We can implement a simple Vision Transformer based model in PyTorch to classify images from the CIFAR100 dataset. This is adapted from the excellent Keras Vision Transformer guide.
First, let's set up the initial global variables and load the dataset. We resize the images to 72x72 and choose a patch size of 6. This means we will have (72*72)/(6²) = 144 patches, each of which becomes a token.
import sys
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0001
BATCH_SIZE = 32
NUM_EPOCHS = 100
IMAGE_SIZE = 72
PATCH_SIZE = 6
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
PROJECTION_DIM = 64
NUM_HEADS = 4
TRANSFORMER_LAYERS = 8
MLP_HEAD_UNITS = [2048, 1024]
# Commonly used (approximate) CIFAR-100 channel statistics for normalization.
mean = (0.5071, 0.4865, 0.4409)
std = (0.2673, 0.2564, 0.2762)
train_transforms = torchvision.transforms.Compose(
[
torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
torchvision.transforms.RandomRotation(degrees=7),
torchvision.transforms.RandomHorizontalFlip(p=0.5),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean, std),
]
)
test_transforms = torchvision.transforms.Compose(
[
torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean, std),
]
)
train_dataset = torchvision.datasets.CIFAR100(
root="./data", train=True, download=True, transform=train_transforms
)
valid_dataset = torchvision.datasets.CIFAR100(
root="./data", train=False, download=True, transform=test_transforms
)
valid_set, test_set = torch.utils.data.random_split(
valid_dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(42)
)
trainloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
pin_memory=True,
num_workers=4,
drop_last=True,
)
validloader = torch.utils.data.DataLoader(
valid_set,
batch_size=BATCH_SIZE,
shuffle=False,
pin_memory=True,
num_workers=4,
drop_last=True,
)
testloader = torch.utils.data.DataLoader(
test_set,
batch_size=BATCH_SIZE,
shuffle=False,
pin_memory=True,
num_workers=4,
drop_last=True,
)
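As a quick sanity check, we can confirm that the loaders produce batches of the expected shape:

# Quick sanity check on the data pipeline.
images, labels = next(iter(trainloader))
print(images.shape)  # torch.Size([32, 3, 72, 72])
print(labels.shape)  # torch.Size([32])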
Model Building Blocks
We will first set up the patch-creation layer. This uses the PyTorch Unfold layer, which produces patches from the spatial dimensions of the image. We then permute the output so that it has the form (Batch Size, Number of Patches, (P²)*C), where P is our patch size and C is the number of channels in the image.
class CreatePatchesLayer(torch.nn.Module):
"""Custom PyTorch Layer to Extract Patches from Images."""
def __init__(
self,
patch_size: int,
strides: int,
) -> None:
"""Init Variables."""
super().__init__()
self.unfold_layer = torch.nn.Unfold(
kernel_size=patch_size, stride=strides
)
def forward(self, images: torch.Tensor) -> torch.Tensor:
"""Forward Pass to Create Patches."""
patched_images = self.unfold_layer(images)
return patched_images.permute((0, 2, 1))
We can test this layer with the code below. As can be seen, given an image, our patch layer splits it into 144 individual patches.
batch_of_images = next(iter(trainloader))[0][0].unsqueeze(dim=0)
plt.figure(figsize=(4, 4))
image = torch.permute(batch_of_images[0], (1, 2, 0)).numpy()
plt.imshow(image)
plt.axis("off")
plt.savefig("img.png", bbox_inches="tight", pad_inches=0)
plt.clf()
patch_layer = CreatePatchesLayer(patch_size=PATCH_SIZE, strides=PATCH_SIZE)
patched_image = patch_layer(batch_of_images)
patched_image = patched_image.squeeze()
plt.figure(figsize=(4, 4))
for idx, patch in enumerate(patched_image):
    ax = plt.subplot(int(NUM_PATCHES**0.5), int(NUM_PATCHES**0.5), idx + 1)
patch_img = torch.reshape(patch, (3, PATCH_SIZE, PATCH_SIZE))
patch_img = torch.permute(patch_img, (1, 2, 0))
plt.imshow(patch_img.numpy())
plt.axis("off")
plt.savefig("patched_img.png", bbox_inches="tight", pad_inches=0)
Image vs Patched Image
We can then create the patch embedding layer. This projects each patch to the embedding dimension with a linear layer, encodes a learned positional embedding into the patches via a PyTorch Embedding layer, and concatenates the randomly initialized class token onto the patched data.
class PatchEmbeddingLayer(torch.nn.Module):
"""Positional Embedding Layer for Images of Patches."""
def __init__(
self,
num_patches: int,
batch_size: int,
patch_size: int,
embed_dim: int,
device: torch.device,
) -> None:
"""Init Function."""
super().__init__()
self.num_patches = num_patches
self.patch_size = patch_size
self.position_emb = torch.nn.Embedding(
num_embeddings=num_patches + 1, embedding_dim=embed_dim
)
self.projection_layer = torch.nn.Linear(
patch_size * patch_size * 3, embed_dim
)
self.class_parameter = torch.nn.Parameter(
torch.rand(batch_size, 1, embed_dim).to(device),
requires_grad=True,
)
self.device = device
def forward(self, patches: torch.Tensor) -> torch.Tensor:
"""Forward Pass."""
positions = (
torch.arange(start=0, end=self.num_patches + 1, step=1)
.to(self.device)
.unsqueeze(dim=0)
)
patches = self.projection_layer(patches)
encoded_patches = torch.cat(
(self.class_parameter, patches), dim=1
) + self.position_emb(positions)
return encoded_patches
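To check the shapes, we can run a batch of patches through the embedding layer and confirm that the class token adds one extra token. This is a small sketch using the globals above; it assumes a full batch of size BATCH_SIZE, since the class parameter is created with a fixed batch dimension.

# Shape check for the embedding layer (sketch).
embed_layer = PatchEmbeddingLayer(
    NUM_PATCHES, BATCH_SIZE, PATCH_SIZE, PROJECTION_DIM, DEVICE
).to(DEVICE)
batch_images = next(iter(trainloader))[0].to(DEVICE)
patch_tokens = patch_layer(batch_images)  # (32, 144, 108)
embedded_tokens = embed_layer(patch_tokens)  # (32, 145, 64)
print(embedded_tokens.shape)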
Finally, we need a transformer layer.
In short, this layer consists of layer normalization, a multi-head attention block, and a final feed-forward network.
class TransformerBlock(torch.nn.Module):
"""Transformer Block Layer."""
def __init__(
self,
num_heads: int,
key_dim: int,
embed_dim: int,
ff_dim: int,
dropout_rate: float = 0.1,
) -> None:
"""Init variables and layers."""
super().__init__()
self.layer_norm_input = torch.nn.LayerNorm(
normalized_shape=embed_dim, eps=1e-6
)
self.attn = torch.nn.MultiheadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
kdim=key_dim,
vdim=key_dim,
batch_first=True,
)
self.dropout_1 = torch.nn.Dropout(p=dropout_rate)
self.layer_norm_1 = torch.nn.LayerNorm(
normalized_shape=embed_dim, eps=1e-6
)
self.layer_norm_2 = torch.nn.LayerNorm(
normalized_shape=embed_dim, eps=1e-6
)
self.ffn = create_mlp_block(
input_features=embed_dim,
output_features=[ff_dim, embed_dim],
activation_function=torch.nn.GELU,
dropout_rate=dropout_rate,
)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""Forward Pass."""
layer_norm_inputs = self.layer_norm_input(inputs)
attention_output, _ = self.attn(
query=layer_norm_inputs,
key=layer_norm_inputs,
value=layer_norm_inputs,
)
attention_output = self.dropout_1(attention_output)
out1 = self.layer_norm_1(inputs + attention_output)
ffn_output = self.ffn(out1)
output = self.layer_norm_2(out1 + ffn_output)
return output
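Both the transformer block above and the classification head later on call a create_mlp_block helper that is not shown in the post. A minimal sketch consistent with its call sites (and with the Linear -> GELU -> Dropout pattern visible in the model summary further down) might look like this:

def create_mlp_block(
    input_features: int,
    output_features: list[int],
    activation_function: type[torch.nn.Module],
    dropout_rate: float,
) -> torch.nn.Sequential:
    """Stack of Linear -> activation -> Dropout layers (sketch)."""
    layers: list[torch.nn.Module] = []
    in_features = input_features
    for out_features in output_features:
        layers.append(torch.nn.Linear(in_features, out_features))
        layers.append(activation_function())
        layers.append(torch.nn.Dropout(p=dropout_rate))
        in_features = out_features
    return torch.nn.Sequential(*layers)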
Network Implementation
We then stack these layers together to create our network. The network starts with the patch layer and the patch embedding layer, followed by a stack of transformer blocks. We then take the final embedding of the class token and feed it through a series of linear layers to produce the logits needed for classification. The resulting network is shown below.
class ViTClassifierModel(torch.nn.Module):
"""ViT Model for Image Classification."""
def __init__(
self,
num_transformer_layers: int,
embed_dim: int,
feed_forward_dim: int,
num_heads: int,
patch_size: int,
num_patches: int,
mlp_head_units: list[int],
num_classes: int,
batch_size: int,
device: torch.device,
) -> None:
"""Init Function."""
super().__init__()
self.create_patch_layer = CreatePatchesLayer(patch_size, patch_size)
self.patch_embedding_layer = PatchEmbeddingLayer(
num_patches, batch_size, patch_size, embed_dim, device
)
self.transformer_layers = torch.nn.ModuleList()
for _ in range(num_transformer_layers):
self.transformer_layers.append(
TransformerBlock(
num_heads, embed_dim, embed_dim, feed_forward_dim
)
)
self.mlp_block = create_mlp_block(
input_features=embed_dim,
output_features=mlp_head_units,
activation_function=torch.nn.GELU,
dropout_rate=0.5,
)
self.logits_layer = torch.nn.Linear(mlp_head_units[-1], num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward Pass."""
x = self.create_patch_layer(x)
x = self.patch_embedding_layer(x)
for transformer_layer in self.transformer_layers:
x = transformer_layer(x)
x = x[:, 0]
x = self.mlp_block(x)
x = self.logits_layer(x)
return x
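The post does not show the model being instantiated. Based on the globals above and the shapes in the summary below (which imply a feed-forward dimension of 128), the instantiation presumably looks something like this:

# Assumed instantiation; NUM_CLASSES and feed_forward_dim are inferred,
# not shown in the original post.
NUM_CLASSES = 100  # CIFAR100
model = ViTClassifierModel(
    num_transformer_layers=TRANSFORMER_LAYERS,
    embed_dim=PROJECTION_DIM,
    feed_forward_dim=2 * PROJECTION_DIM,
    num_heads=NUM_HEADS,
    patch_size=PATCH_SIZE,
    num_patches=NUM_PATCHES,
    mlp_head_units=MLP_HEAD_UNITS,
    num_classes=NUM_CLASSES,
    batch_size=BATCH_SIZE,
    device=DEVICE,
).to(DEVICE)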
----------------------------------------------------------------------------------------------------------------
Parent Layers Layer (type) Output Shape Param # Tr. Param #
================================================================================================================
ViTClassifierModel/CreatePatchesLayer Unfold-1 [32, 108, 144] 0 0
ViTClassifierModel/PatchEmbeddingLayer Linear-2 [32, 144, 64] 6,976 6,976
ViTClassifierModel/PatchEmbeddingLayer Embedding-3 [1, 145, 64] 9,280 9,280
ViTClassifierModel/TransformerBlock LayerNorm-4 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-5 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-6 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-7 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-8 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-9 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-10 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-11 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-12 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-13 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-14 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-15 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-16 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-17 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-18 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-19 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-20 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-21 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-22 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-23 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-24 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-25 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-26 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-27 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-28 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-29 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-30 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-31 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-32 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-33 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-34 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-35 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-36 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-37 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-38 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-39 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-40 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-41 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-42 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-43 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-44 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-45 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-46 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-47 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-48 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-49 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-50 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-51 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-52 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-53 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-54 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-55 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-56 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-57 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-58 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-59 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-60 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-61 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-62 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-63 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-64 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-65 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-66 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-67 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-68 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-69 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-70 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-71 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-72 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-73 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-74 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-75 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-76 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-77 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-78 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-79 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-80 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-81 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-82 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-83 [32, 145, 64] 128 128
ViTClassifierModel Linear-84 [32, 2048] 133,120 133,120
ViTClassifierModel GELU-85 [32, 2048] 0 0
ViTClassifierModel Dropout-86 [32, 2048] 0 0
ViTClassifierModel Linear-87 [32, 1024] 2,098,176 2,098,176
ViTClassifierModel GELU-88 [32, 1024] 0 0
ViTClassifierModel Dropout-89 [32, 1024] 0 0
ViTClassifierModel Linear-90 [32, 100] 102,500 102,500
================================================================================================================
Total params: 2,485,732
Trainable params: 2,485,732
Non-trainable params: 0
----------------------------------------------------------------------------------------------------------------
Classification and Results
Training then proceeds in the same way as for a regular convolutional network: we feed in the data, compute the loss against the labels, and compute our accuracy. The training code for this is shown below.
def train_network(
model: torch.nn.Module,
num_epochs: int,
optimizer: torch.optim.Optimizer,
loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
trainloader: torch.utils.data.DataLoader,
validloader: torch.utils.data.DataLoader,
device: torch.device,
) -> None:
"""Train the Network."""
print("Training Started")
for epoch in range(1, num_epochs + 1):
sys.stdout.flush()
train_loss = []
valid_loss = []
num_examples_train = 0
num_correct_train = 0
num_examples_valid = 0
num_correct_valid = 0
num_correct_train_5 = 0
num_correct_valid_5 = 0
model.train()
for batch in trainloader:
optimizer.zero_grad()
x = batch[0].to(device)
y = batch[1].to(device)
outputs = model(x)
loss = loss_function(outputs, y)
loss.backward()
optimizer.step()
train_loss.append(loss.item())
num_corr, num_ex = calculate_accuracy(outputs, y)
num_corr_5, _ = calculate_accuracy_top_5(outputs, y)
num_examples_train += num_ex
num_correct_train += num_corr
num_correct_train_5 += num_corr_5
model.eval()
with torch.no_grad():
for batch in validloader:
images = batch[0].to(device)
labels = batch[1].to(device)
outputs = model(images)
loss = loss_function(outputs, labels)
valid_loss.append(loss.item())
num_corr, num_ex = calculate_accuracy(outputs, labels)
num_corr_5, _ = calculate_accuracy_top_5(outputs, labels)
num_examples_valid += num_ex
num_correct_valid += num_corr
num_correct_valid_5 += num_corr_5
print(
f"Epoch: {epoch}, Training Loss: {np.mean(train_loss):.4f}, Validation Loss: {np.mean(valid_loss):.4f}, Training Accuracy: {num_correct_train/num_examples_train:.4f}, Validation Accuracy: {num_correct_valid/num_examples_valid:.4f}, Training Accuracy Top-5: {num_correct_train_5/num_examples_train:.4f}, Validation Accuracy Top-5: {num_correct_valid_5/num_examples_valid:.4f}"
)
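The loop above relies on two accuracy helpers, calculate_accuracy and calculate_accuracy_top_5, which are not defined in the post. Minimal sketches consistent with how they are used (each returning the number of correct predictions and the number of examples) are:

def calculate_accuracy(
    outputs: torch.Tensor, labels: torch.Tensor
) -> tuple[int, int]:
    """Return (number correct, number of examples) for top-1 predictions."""
    predictions = outputs.argmax(dim=1)
    num_correct = (predictions == labels).sum().item()
    return num_correct, labels.size(0)


def calculate_accuracy_top_5(
    outputs: torch.Tensor, labels: torch.Tensor
) -> tuple[int, int]:
    """Return (number correct, number of examples) for top-5 predictions."""
    top5 = outputs.topk(k=5, dim=1).indices
    num_correct = (top5 == labels.unsqueeze(dim=1)).any(dim=1).sum().item()
    return num_correct, labels.size(0)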
We train for 100 epochs using the AdamW optimizer with a learning rate of 0.001 and a weight decay of 0.0001.
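The exact training call is not shown in the post, but given the hyperparameters above it presumably looks something like this:

# Assumed training setup (not shown in the original post).
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
train_network(
    model=model,
    num_epochs=NUM_EPOCHS,
    optimizer=optimizer,
    loss_function=loss_function,
    trainloader=trainloader,
    validloader=validloader,
    device=DEVICE,
)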
As can be seen from the graph below, we achieve roughly 50% top-1 validation accuracy and 77% top-5 validation accuracy.

Accuracy vs Epoch Graphs
Conclusion
Hopefully this blog post has helped demystify the Vision Transformer architecture and its implementation. As you can see, it provides an alternative to convolutional layers that can be applied in relatively small networks to successfully classify images.