Implementing a Vision Transformer Classifier in PyTorch

I decided to look into an extension of the regular transformer: the Vision Transformer. As the name suggests, this type of transformer takes an image as input instead of a sequence of words. This blog will outline the architecture of the Vision Transformer and implement a Vision-Transformer-based classifier on the CIFAR100 dataset.

Vision Transformer Architecture

As the name suggests, Vision Transformers are transformers applied to image data. They are usually encoder-based, which means no self-attention mask is needed: every query vector can attend to every key vector when producing attention weights. Unlike text sequences, images are not naturally suited to being fed into a transformer, so the main design decision is how to turn the input image into tokens. We could use every pixel of the image as a token, but the memory a transformer needs grows quadratically with the number of input tokens, so this quickly becomes infeasible as the spatial size of the image grows (a 224x224 image already has 50,176 pixels, giving an attention matrix with roughly 2.5 billion entries).

Instead, we split the input image into patches of size PxPxC. These patches are then flattened to form a matrix of size Nx(P*P*C). For an image of size HxW this gives N = (H*W)/P² patches.

Each patch is then fed into the transformer as a token, with a learned positional encoding added to it. Images in a dataset are usually of a fixed size, so a Vision Transformer sees a fixed number of tokens, in contrast to text-based transformers.
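
As a concrete example, for a 72x72 RGB image split into 6x6 patches (the configuration used later in this post), torch.nn.Unfold, which is also what the patch layer below uses, produces exactly these shapes. A quick sanity-check sketch:

import torch

image = torch.randn(1, 3, 72, 72)                  # (B, C, H, W)
unfold = torch.nn.Unfold(kernel_size=6, stride=6)  # non-overlapping 6x6 patches
patches = unfold(image).permute(0, 2, 1)           # (B, N, P*P*C)
print(patches.shape)                               # torch.Size([1, 144, 108])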

The general architecture of a Vision Transformer is shown below.

One interesting aspect is the addition of a randomly initialised, learnable parameter called the class token, which is prepended to the input. As it passes through the transformer layers of the network, the class token can accumulate information from the other tokens in the sequence: when acting as a query in the attention mechanism, it aggregates information from all patches through the attention weights.

Classification is then performed by feeding the class token from the final layer into a linear layer. If we instead concatenated all tokens from the final layer and fed them into the classification head, the network would need a huge number of extra parameters, which would be inefficient and impractical (for the configuration used later in this post, the first head layer would act on 145*64 = 9,280 features instead of 64, roughly 19 million parameters instead of about 131 thousand). Having a separate class token encourages the Vision Transformer to learn a general representation of the whole sequence into that token, rather than biasing the final output towards any single patch token.

We can form a Vision-Transformer-based network, much like a GPT model, by stacking transformer layers on top of each other and adding a classification head at the end. Since the Vision Transformer acts as an encoder, we do not need to worry about any attention masks in the model.


Vision Transformer Implementation

We can implement a simple Vision-Transformer-based model in PyTorch to classify images from the CIFAR100 dataset. This is adapted from the excellent Keras Vision Transformer guide.

First, let's set up the initial global variables and load the dataset. We resize the images to 72x72 and choose a patch size of 6. This means we will have (72*72)/6² = 144 patches, each of which becomes a token.

import torch
import torchvision

# Commonly quoted CIFAR100 channel statistics (the original post does not show
# the normalisation values it used).
mean = (0.5071, 0.4865, 0.4409)
std = (0.2673, 0.2564, 0.2762)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0001
BATCH_SIZE = 32
NUM_EPOCHS = 100
IMAGE_SIZE = 72
PATCH_SIZE = 6
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
PROJECTION_DIM = 64
NUM_HEADS = 4
TRANSFORMER_LAYERS = 8
MLP_HEAD_UNITS = [2048, 1024]

train_transforms = torchvision.transforms.Compose(
  [
    torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    torchvision.transforms.RandomRotation(degrees=7),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean, std),
  ]
)

test_transforms = torchvision.transforms.Compose(
  [
    torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean, std),
  ]
)

train_dataset = torchvision.datasets.CIFAR100(
  root="./data", train=True, download=True, transform=train_transforms
)

valid_dataset = torchvision.datasets.CIFAR100(
  root="./data", train=False, download=True, transform=test_transforms
)

valid_set, test_set = torch.utils.data.random_split(
  valid_dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(42)
)

trainloader = torch.utils.data.DataLoader(
  train_dataset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  pin_memory=True,
  num_workers=4,
  drop_last=True,
)
validloader = torch.utils.data.DataLoader(
  valid_set,
  batch_size=BATCH_SIZE,
  shuffle=False,
  pin_memory=True,
  num_workers=4,
  drop_last=True,
)
testloader = torch.utils.data.DataLoader(
  test_set,
  batch_size=BATCH_SIZE,
  shuffle=False,
  pin_memory=True,
  num_workers=4,
  drop_last=True,
)

Model Building Blocks

We will first set up the patch-creation layer. This uses PyTorch's Unfold layer, which produces patches from the spatial dimensions of the image. We then permute the output so that it has the shape (Batch Size, Number of Patches, (P²)*C), where P is our patch size and C is the number of channels in the image.

class CreatePatchesLayer(torch.nn.Module):
  """Custom PyTorch Layer to Extract Patches from Images."""

  def __init__(
    self,
    patch_size: int,
    strides: int,
  ) -> None:
    """Init Variables."""
    super().__init__()
    self.unfold_layer = torch.nn.Unfold(
      kernel_size=patch_size, stride=strides
    )

  def forward(self, images: torch.Tensor) -> torch.Tensor:
    """Forward Pass to Create Patches."""
    patched_images = self.unfold_layer(images)
    return patched_images.permute((0, 2, 1))

We can test this layer with the code below. As can be seen, given an image, our patch layer splits it into 144 individual patches.

import matplotlib.pyplot as plt

batch_of_images = next(iter(trainloader))[0][0].unsqueeze(dim=0)

plt.figure(figsize=(4, 4))
image = torch.permute(batch_of_images[0], (1, 2, 0)).numpy()
plt.imshow(image)
plt.axis("off")
plt.savefig("img.png", bbox_inches="tight", pad_inches=0)
plt.clf()

patch_layer = CreatePatchesLayer(patch_size=PATCH_SIZE, strides=PATCH_SIZE)
patched_image = patch_layer(batch_of_images)
patched_image = patched_image.squeeze()

plt.figure(figsize=(4, 4))
n = int(NUM_PATCHES ** 0.5)  # 144 patches -> a 12x12 grid
for idx, patch in enumerate(patched_image):
  ax = plt.subplot(n, n, idx + 1)
  patch_img = torch.reshape(patch, (3, PATCH_SIZE, PATCH_SIZE))
  patch_img = torch.permute(patch_img, (1, 2, 0))
  plt.imshow(patch_img.numpy())
  plt.axis("off")
plt.savefig("patched_img.png", bbox_inches="tight", pad_inches=0)

 

Image vs Patched Image

We can then create a patch-embedding layer. This projects each flattened patch to the embedding dimension with a linear layer, adds a learned positional embedding via a PyTorch Embedding layer, and concatenates the randomly initialised class token onto the patched data.

class PatchEmbeddingLayer(torch.nn.Module):
  """Positional Embedding Layer for Images of Patches."""

  def __init__(
    self,
    num_patches: int,
    batch_size: int,
    patch_size: int,
    embed_dim: int,
    device: torch.device,
  ) -> None:
    """Init Function."""
    super().__init__()
    self.num_patches = num_patches
    self.patch_size = patch_size
    self.position_emb = torch.nn.Embedding(
      num_embeddings=num_patches + 1, embedding_dim=embed_dim
    )
    self.projection_layer = torch.nn.Linear(
      patch_size * patch_size * 3, embed_dim
    )
    # Learnable class token. Giving it an explicit batch dimension ties the
    # model to a fixed batch size, which is why the data loaders use drop_last=True.
    self.class_parameter = torch.nn.Parameter(
      torch.rand(batch_size, 1, embed_dim).to(device),
      requires_grad=True,
    )
    self.device = device

  def forward(self, patches: torch.Tensor) -> torch.Tensor:
    """Forward Pass."""
    positions = (
      torch.arange(start=0, end=self.num_patches + 1, step=1)
      .to(self.device)
      .unsqueeze(dim=0)
    )
    patches = self.projection_layer(patches)
    encoded_patches = torch.cat(
      (self.class_parameter, patches), dim=1
    ) + self.position_emb(positions)
    return encoded_patches
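
As a quick shape check (my own sketch, assuming the globals defined earlier), feeding a batch of flattened patches through this layer should prepend the class token and add the positional embedding, giving 145 tokens of dimension 64:

embed_layer = PatchEmbeddingLayer(
  NUM_PATCHES, BATCH_SIZE, PATCH_SIZE, PROJECTION_DIM, DEVICE
).to(DEVICE)
dummy_patches = torch.randn(
  BATCH_SIZE, NUM_PATCHES, PATCH_SIZE * PATCH_SIZE * 3
).to(DEVICE)
print(embed_layer(dummy_patches).shape)  # torch.Size([32, 145, 64])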

We then finally need a transformer layer.

In short, this layer consists of layer normalisation, a multi-head attention block, and a final feed-forward network.
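
Both the transformer block below and the classification head later on call a helper named create_mlp_block, which is never defined in the post. Judging from how it is called and from the Linear/GELU/Dropout pattern visible in the model summary further down, a minimal sketch consistent with that usage could be:

def create_mlp_block(
  input_features: int,
  output_features: list[int],
  activation_function: type[torch.nn.Module],
  dropout_rate: float,
) -> torch.nn.Sequential:
  """Stack Linear -> activation -> Dropout for each requested output size."""
  layers: list[torch.nn.Module] = []
  in_features = input_features
  for out_features in output_features:
    layers.append(torch.nn.Linear(in_features, out_features))
    layers.append(activation_function())
    layers.append(torch.nn.Dropout(p=dropout_rate))
    in_features = out_features
  return torch.nn.Sequential(*layers)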

class TransformerBlock(torch.nn.Module):
  """Transformer Block Layer."""

  def __init__(
    self,
    num_heads: int,
    key_dim: int,
    embed_dim: int,
    ff_dim: int,
    dropout_rate: float = 0.1,
  ) -> None:
    """Init variables and layers."""
    super().__init__()
    self.layer_norm_input = torch.nn.LayerNorm(
      normalized_shape=embed_dim, eps=1e-6
    )
    self.attn = torch.nn.MultiheadAttention(
      embed_dim=embed_dim,
      num_heads=num_heads,
      kdim=key_dim,
      vdim=key_dim,
      batch_first=True,
    )

    self.dropout_1 = torch.nn.Dropout(p=dropout_rate)
    self.layer_norm_1 = torch.nn.LayerNorm(
      normalized_shape=embed_dim, eps=1e-6
    )
    self.layer_norm_2 = torch.nn.LayerNorm(
      normalized_shape=embed_dim, eps=1e-6
    )
    self.ffn = create_mlp_block(
      input_features=embed_dim,
      output_features=[ff_dim, embed_dim],
      activation_function=torch.nn.GELU,
      dropout_rate=dropout_rate,
    )

  def forward(self, inputs: torch.Tensor) -> torch.Tensor:
    """Forward Pass."""
    layer_norm_inputs = self.layer_norm_input(inputs)
    attention_output, _ = self.attn(
      query=layer_norm_inputs,
      key=layer_norm_inputs,
      value=layer_norm_inputs,
    )
    attention_output = self.dropout_1(attention_output)
    out1 = self.layer_norm_1(inputs + attention_output)
    ffn_output = self.ffn(out1)
    output = self.layer_norm_2(out1 + ffn_output)
    return output
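
A quick shape check of a single block (my own sketch; an ff_dim of 128 is consistent with the 128-unit Linear layers in the model summary below):

block = TransformerBlock(
  num_heads=NUM_HEADS,
  key_dim=PROJECTION_DIM,
  embed_dim=PROJECTION_DIM,
  ff_dim=2 * PROJECTION_DIM,
)
tokens = torch.randn(BATCH_SIZE, NUM_PATCHES + 1, PROJECTION_DIM)
print(block(tokens).shape)  # torch.Size([32, 145, 64])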

Network Implementation

We then stack these layers together to create our network. The network first consists of the patch layer and the patch-embedding layer, followed by a stack of transformer blocks. We then take the final embedding of the class token and feed it through a series of linear layers to produce the logits needed for classification. The resulting network is shown below.

class ViTClassifierModel(torch.nn.Module):
  """ViT Model for Image Classification."""

  def __init__(
    self,
    num_transformer_layers: int,
    embed_dim: int,
    feed_forward_dim: int,
    num_heads: int,
    patch_size: int,
    num_patches: int,
    mlp_head_units: list[int],
    num_classes: int,
    batch_size: int,
    device: torch.device,
  ) -> None:
    """Init Function."""
    super().__init__()
    self.create_patch_layer = CreatePatchesLayer(patch_size, patch_size)
    self.patch_embedding_layer = PatchEmbeddingLayer(
      num_patches, batch_size, patch_size, embed_dim, device
    )
    self.transformer_layers = torch.nn.ModuleList()
    for _ in range(num_transformer_layers):
      self.transformer_layers.append(
        TransformerBlock(
          num_heads, embed_dim, embed_dim, feed_forward_dim
        )
      )

    self.mlp_block = create_mlp_block(
      input_features=embed_dim,
      output_features=mlp_head_units,
      activation_function=torch.nn.GELU,
      dropout_rate=0.5,
    )

    self.logits_layer = torch.nn.Linear(mlp_head_units[-1], num_classes)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward Pass."""
    x = self.create_patch_layer(x)
    x = self.patch_embedding_layer(x)
    for transformer_layer in self.transformer_layers:
      x = transformer_layer(x)
    x = x[:, 0]
    x = self.mlp_block(x)
    x = self.logits_layer(x)
    return x
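
The post does not show the constructor call that produces the summary below; a sketch consistent with the globals above and with the 100-class output visible in the summary is:

NUM_CLASSES = 100  # CIFAR100
model = ViTClassifierModel(
  num_transformer_layers=TRANSFORMER_LAYERS,
  embed_dim=PROJECTION_DIM,
  feed_forward_dim=2 * PROJECTION_DIM,
  num_heads=NUM_HEADS,
  patch_size=PATCH_SIZE,
  num_patches=NUM_PATCHES,
  mlp_head_units=MLP_HEAD_UNITS,
  num_classes=NUM_CLASSES,
  batch_size=BATCH_SIZE,
  device=DEVICE,
).to(DEVICE)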
----------------------------------------------------------------------------------------------------------------
                            Parent Layers       Layer (type)        Output Shape         Param #     Tr. Param #
================================================================================================================
    ViTClassifierModel/CreatePatchesLayer           Unfold-1      [32, 108, 144]               0               0
   ViTClassifierModel/PatchEmbeddingLayer           Linear-2       [32, 144, 64]           6,976           6,976
   ViTClassifierModel/PatchEmbeddingLayer        Embedding-3        [1, 145, 64]           9,280           9,280
      ViTClassifierModel/TransformerBlock        LayerNorm-4       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock          Dropout-5       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock        LayerNorm-6       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock           Linear-7      [32, 145, 128]           8,320           8,320
      ViTClassifierModel/TransformerBlock             GELU-8      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock          Dropout-9      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock          Linear-10       [32, 145, 64]           8,256           8,256
      ViTClassifierModel/TransformerBlock            GELU-11       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-12       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-13       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock       LayerNorm-14       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock         Dropout-15       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-16       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock          Linear-17      [32, 145, 128]           8,320           8,320
      ViTClassifierModel/TransformerBlock            GELU-18      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-19      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock          Linear-20       [32, 145, 64]           8,256           8,256
      ViTClassifierModel/TransformerBlock            GELU-21       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-22       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-23       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock       LayerNorm-24       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock         Dropout-25       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-26       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock          Linear-27      [32, 145, 128]           8,320           8,320
      ViTClassifierModel/TransformerBlock            GELU-28      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-29      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock          Linear-30       [32, 145, 64]           8,256           8,256
      ViTClassifierModel/TransformerBlock            GELU-31       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-32       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-33       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock       LayerNorm-34       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock         Dropout-35       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-36       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock          Linear-37      [32, 145, 128]           8,320           8,320
      ViTClassifierModel/TransformerBlock            GELU-38      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-39      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock          Linear-40       [32, 145, 64]           8,256           8,256
      ViTClassifierModel/TransformerBlock            GELU-41       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-42       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-43       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock       LayerNorm-44       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock         Dropout-45       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-46       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock          Linear-47      [32, 145, 128]           8,320           8,320
      ViTClassifierModel/TransformerBlock            GELU-48      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-49      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock          Linear-50       [32, 145, 64]           8,256           8,256
      ViTClassifierModel/TransformerBlock            GELU-51       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-52       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-53       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock       LayerNorm-54       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock         Dropout-55       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-56       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock          Linear-57      [32, 145, 128]           8,320           8,320
      ViTClassifierModel/TransformerBlock            GELU-58      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-59      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock          Linear-60       [32, 145, 64]           8,256           8,256
      ViTClassifierModel/TransformerBlock            GELU-61       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-62       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-63       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock       LayerNorm-64       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock         Dropout-65       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-66       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock          Linear-67      [32, 145, 128]           8,320           8,320
      ViTClassifierModel/TransformerBlock            GELU-68      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-69      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock          Linear-70       [32, 145, 64]           8,256           8,256
      ViTClassifierModel/TransformerBlock            GELU-71       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-72       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-73       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock       LayerNorm-74       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock         Dropout-75       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-76       [32, 145, 64]             128             128
      ViTClassifierModel/TransformerBlock          Linear-77      [32, 145, 128]           8,320           8,320
      ViTClassifierModel/TransformerBlock            GELU-78      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-79      [32, 145, 128]               0               0
      ViTClassifierModel/TransformerBlock          Linear-80       [32, 145, 64]           8,256           8,256
      ViTClassifierModel/TransformerBlock            GELU-81       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock         Dropout-82       [32, 145, 64]               0               0
      ViTClassifierModel/TransformerBlock       LayerNorm-83       [32, 145, 64]             128             128
                       ViTClassifierModel          Linear-84          [32, 2048]         133,120         133,120
                       ViTClassifierModel            GELU-85          [32, 2048]               0               0
                       ViTClassifierModel         Dropout-86          [32, 2048]               0               0
                       ViTClassifierModel          Linear-87          [32, 1024]       2,098,176       2,098,176
                       ViTClassifierModel            GELU-88          [32, 1024]               0               0
                       ViTClassifierModel         Dropout-89          [32, 1024]               0               0
                       ViTClassifierModel          Linear-90           [32, 100]         102,500         102,500
================================================================================================================
Total params: 2,485,732
Trainable params: 2,485,732
Non-trainable params: 0
----------------------------------------------------------------------------------------------------------------

Classification and Results

Training then proceeds in the same way as for a regular convolutional network: we feed in the data, compute the loss against the labels, and then compute our accuracy. The training code for this is shown below.

import sys
from typing import Callable

import numpy as np


def train_network(
  model: torch.nn.Module,
  num_epochs: int,
  optimizer: torch.optim.Optimizer,
  loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
  trainloader: torch.utils.data.DataLoader,
  validloader: torch.utils.data.DataLoader,
  device: torch.device,
) -> None:
  """Train the Network."""
  print("Training Started")
  for epoch in range(1, num_epochs + 1):
    sys.stdout.flush()
    train_loss = []
    valid_loss = []
    num_examples_train = 0
    num_correct_train = 0
    num_examples_valid = 0
    num_correct_valid = 0
    num_correct_train_5 = 0
    num_correct_valid_5 = 0
    model.train()
    for batch in trainloader:
      optimizer.zero_grad()
      x = batch[0].to(device)
      y = batch[1].to(device)
      outputs = model(x)
      loss = loss_function(outputs, y)
      loss.backward()
      optimizer.step()
      train_loss.append(loss.item())
      num_corr, num_ex = calculate_accuracy(outputs, y)
      num_corr_5, _ = calculate_accuracy_top_5(outputs, y)
      num_examples_train += num_ex
      num_correct_train += num_corr
      num_correct_train_5 += num_corr_5

    model.eval()
    with torch.no_grad():
      for batch in validloader:
        images = batch[0].to(device)
        labels = batch[1].to(device)
        outputs = model(images)
        loss = loss_function(outputs, labels)
        valid_loss.append(loss.item())
        num_corr, num_ex = calculate_accuracy(outputs, labels)
        num_corr_5, _ = calculate_accuracy_top_5(outputs, labels)
        num_examples_valid += num_ex
        num_correct_valid += num_corr
        num_correct_valid_5 += num_corr_5

    print(
      f"Epoch: {epoch}, Training Loss: {np.mean(train_loss):.4f}, Validation Loss: {np.mean(valid_loss):.4f}, Training Accuracy: {num_correct_train/num_examples_train:.4f}, Validation Accuracy: {num_correct_valid/num_examples_valid:.4f}, Training Accuracy Top-5: {num_correct_train_5/num_examples_train:.4f}, Validation Accuracy Top-5: {num_correct_valid_5/num_examples_valid:.4f}"
    )
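
The loop above relies on two helpers, calculate_accuracy and calculate_accuracy_top_5, that the post does not define. A minimal sketch consistent with how they are called, each returning the number of correct predictions and the number of examples in the batch:

def calculate_accuracy(
  outputs: torch.Tensor, labels: torch.Tensor
) -> tuple[int, int]:
  """Return (number correct, number of examples) for top-1 predictions."""
  predictions = outputs.argmax(dim=1)
  return (predictions == labels).sum().item(), labels.size(0)


def calculate_accuracy_top_5(
  outputs: torch.Tensor, labels: torch.Tensor
) -> tuple[int, int]:
  """Return (number correct, number of examples) for top-5 predictions."""
  top5 = outputs.topk(5, dim=1).indices
  num_correct = (top5 == labels.unsqueeze(dim=1)).any(dim=1).sum().item()
  return num_correct, labels.size(0)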

We train for 100 epochs using the AdamW optimizer, with a learning rate of 0.001 and a weight decay of 0.0001.
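
The optimizer and loss set-up are not shown in the post; a sketch consistent with the description above (cross-entropy loss is my assumption for this 100-class problem) is:

optimizer = torch.optim.AdamW(
  model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
loss_function = torch.nn.CrossEntropyLoss()  # assumed; the post does not name the loss

train_network(
  model=model,
  num_epochs=NUM_EPOCHS,
  optimizer=optimizer,
  loss_function=loss_function,
  trainloader=trainloader,
  validloader=validloader,
  device=DEVICE,
)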

As can be seen from the figures below, we achieve roughly 50% top-1 validation accuracy and 77% top-5 validation accuracy.

Loss vs Epoch

Accuracy vs Epoch

Conclusion

Hopefully this blog has helped demystify the Vision Transformer architecture and its implementation. As you have seen, it offers an alternative to convolutional layers and can be applied in relatively small networks to classify images successfully.


 

 
