如何更有效地学习视觉表征?


✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:传知代码论文复现

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

目录

概述

算法原理

核心逻辑

效果演示

使用方式

参考文献


本文所有资源均可在该地址处获取。

概述

本文基于Pytorch复现论文 Towards Effective Visual Representations for Partial-Label Learning[1] 提出的偏标记学习方法。

在一般的监督学习任务中,数据集中的每个样例都拥有一个真实标签。然而,在偏标记学习(Partial Label Learning, PLL)任务中,每个样本仅拥有一个包含了多个标签的候选标签集,其中有且仅有一个是真实标签。Pico[2] 在一个统一框架中解决了 PLL 中的两个关键问题——表示学习和标签消歧,该框架包含一个对比学习模块以及一个基于类原型的标签消歧算法。此外,Pico 添加了一个原型分类器模块,以指导线性分类器的更新。虽然 Pico 在大部分基准测试上取得了最佳性能,但其仍存在两个问题:(1)对比学习模块所引入的伪标签中的噪声导致了显著的性能下降;(2)在训练的开始阶段,线性分类器比原型分类器更加准确。

为了解决上述两个问题,该论文提出了一个新的偏标记学习框架 Papi。Papi 删除了 Pico 中的对比学习模块,并采用与 Pico 相反的指导方向。具体来说,Papi 基于低维特征空间中样本特征与类原型之间的距离,通过 softmax 函数为每个样本生成一个类别相似度分布。然后,Papi将该分布与一个线性分类器的预测概率对齐。同时,该线性分类器进行自我教学,在每个学习阶段由之前的阶段指导。

算法原理

首先,对每一个样本 (xi,Yi)(xi​,Yi​),Papi 生成一个弱增强视图 aug1(xi)aug1​(xi​) 和一个强增强视图 aug2(xi)aug2​(xi​)。然后,每个视图都会被输入共享权重编码器网络并输出一对表征 vi1=f(aug1(xi))vi1​=f(aug1​(xi​)) 和 vi2=f(aug2(xi))vi2​=f(aug2​(xi​))。进一步地,表征 vi1vi1​ 会被输入线性分类器得到输出 ri=h(vi1)ri​=h(vi1​)。此外,表征都会被输入一个投影网络并被映射为 zil=g(vil)zil​=g(vil​)。

标签消歧

为了让线性分类器进行自我教学,Papi 根据线性分类器上一轮的输出 riri​ 计算本轮的目标:

uij={rij∑l∈Yirilif j∈Yi,0otherwiseuij​={∑l∈Yi​​ril​rij​​0​if j∈Yi​,otherwise​

考虑到训练刚开始时模型十分不稳定,Papi 使用移动平均策略更新学习目标:

pi=λpi+(1−λ)uipi​=λpi​+(1−λ)ui​

Papi 使用交叉熵损失函数计算损失的第一项:

Lclai=−∑j=1Kpij⋅log⁡rijLclai​=−j=1∑K​pij​⋅logrij​

原型对齐

Papi 利用 softmaxsoftmax 函数计算一个样本与类原型的相似度分布:

sij=exp⁡(zi⋅cj/τ)∑k=1Kexp⁡(zi⋅ck/τ)sij​=∑k=1K​exp(zi​⋅ck​/τ)exp(zi​⋅cj​/τ)​

对于类原型表征的计算,Papi 则使用移动平均策略进行更新:

ck=γck+(1−γ)zj, if j∈Ikck​=γck​+(1−γ)zj​, if j∈Ik​

为了利用线性分类器指导原型分类器,Papi 计算标签分布 pipi​ 与类原型相似度 sisi​ 之间的KL散度作为损失:

Lpai(pi,si)=∑l=12DKL(pi∣∣sil)Lpai​(pi​,si​)=l=1∑2​DKL​(pi​∣∣sil​)

为了学习到更稳健的特征,Papi使用 Mixup 来构建训练样本:

x^i=ϕxi+(1−ϕ)xm(i)x^i​=ϕxi​+(1−ϕ)xm(i)​

其中 ϕ∼Beta(α,α)ϕ∼Beta(α,α) 且 m(i)m(i) 为随机选取的。

由此得到了总损失的第二项:

Lalii=ϕLpai(pi,s^i)+(1−ϕ)Lpai(pm(i),s^i)Lalii​=ϕLpai​(pi​,s^i​)+(1−ϕ)Lpai​(pm​(i),s^i​)

模型更新:

总的损失函数如下损失:

Li=Lclai+φ(t)⋅LaliiLi=Lclai​+φ(t)⋅Lalii​

其中 φ(t)φ(t) 为由迭代轮数决定的动态平衡系数。

核心逻辑

具体的核心逻辑如下所示:

import models
import datasets
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms import v2
from tqdm import tqdm
import json
import os
import augmentors as aug
import math


def proden_loss(probs, targets):
    sample_loss = -torch.sum(targets * torch.log(probs), dim = -1)
    mean_loss = torch.sum(sample_loss)/probs.shape[0]
    return sample_loss, mean_loss

class PapiTrainer:
    def __init__(self, configs):
        self.configs = configs
        self.base_lr = configs['learning_rate']
        self.lr_decay = configs['learning_rate_decay_rate']
        self.epochs = configs['epoch_count']

    def adjust_learning_rate(self, optimizer, epoch):
        lr_min = self.base_lr * (self.lr_decay ** 3)
        lr = lr_min + (self.base_lr - lr_min) * (1 + math.cos(math.pi * epoch / self.epochs)) / 2
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    
    def train(self):
        configs = self.configs
        # Papi超参数
        target_update_rate = configs['target_update_rate']
        similarity_temperature = configs['similarity_temperature']
        prototype_update_rate = configs['prototype_update_rate']
        mixup_factor = float(configs['mixup_factor'])
        mix_dist = torch.distributions.beta.Beta(torch.tensor([mixup_factor]), torch.tensor([mixup_factor]))
        loss_weight_update_epochs = configs['loss_weight_update_epochs']
        loss_weight = configs['loss_weight']
        low_dimension = configs['low_dimension']
        # 读取数据集
        dataset_path = configs['dataset_path']
        if not os.path.exists(dataset_path):
            os.mkdir(dataset_path)
        if configs['dataset'] == 'CIFAR-10':
            train_dataset = datasets.CIFAR10(dataset_path, train=True)
            test_dataset = datasets.CIFAR10(dataset_path, train=False)
            output_dimension = 10
            mean = [0.49139968, 0.48215827, 0.44653124] 
            std = [0.24703233, 0.24348505, 0.26158768]
        elif configs['dataset'] == 'CIFAR-100':
            train_dataset = datasets.CIFAR100(dataset_path, train=True)
            test_dataset = datasets.CIFAR100(dataset_path, train=False)
            output_dimension = 100
            mean = [0.50707516, 0.48654887, 0.44091784]
            std = [0.26733429, 0.25643846, 0.27615047]
        else:
            print('No such dataset')
            exit()
        test_dataloader = DataLoader(test_dataset, batch_size = configs['batch_size'], shuffle = False)
        class_num = output_dimension
        # 生成偏标记
        labels = np.array([sample['labels'] for sample in train_dataset])
        partial_labels = datasets.random_flip(labels, configs['partial_rate'])
        train_dataset.load_partial_labels(partial_labels)
        # 模型输入的标准化
        normalize = v2.Normalize(mean, std)
        # 数据增强
        weak_aug = v2.Compose([
            v2.RandomHorizontalFlip(),
            v2.RandomCrop(32, 4, padding_mode='reflect'),
        ])
        strong_aug = aug.StrongAugment()
        # 设备:GPU或CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # 加载模型
        model = models.PapiModel(class_num, low_dimension).to(device)
        # 设置学习率等超参数
        lr = configs['learning_rate']
        weight_decay = configs['weight_decay']
        momentum = configs['momentum']
        optimizer = optim.SGD(model.parameters(), lr = lr, weight_decay = weight_decay, momentum = momentum)
        lr_decay = configs['learning_rate_decay_rate']
        # 类原型
        weak_prototypes = torch.zeros(class_num, low_dimension).to(device).unsqueeze(0)
        strong_prototypes = torch.zeros(class_num, low_dimension).to(device).unsqueeze(0)
        # 训练
        for epoch_id in range(configs['epoch_count']):
            train_dataloader = DataLoader(train_dataset, batch_size = configs['batch_size'], shuffle = True)
            loss_sum = 0
            for batch in tqdm(train_dataloader, desc='Training(Epoch %d)' % epoch_id, ascii=' 123456789#'):
                ids = batch['ids']
                data = batch['data'].to(device)
                partial_labels = batch['partial_labels'].to(device)
                targets = batch['targets'].to(device)
                labels = batch['labels'].to(device)
                # 计算并更新标签分布
                model.train()
                optimizer.zero_grad()
                weak_logits, weak_projection = model(normalize(weak_aug(data)), logit = True, proj=True)
                probs = F.softmax(weak_logits, dim=-1)
                with torch.no_grad():
                    update_rate = target_update_rate[0] + (target_update_rate[1] - target_update_rate[0]) * epoch_id / configs['epoch_count']
                    new_targets = F.normalize(probs.detach() * partial_labels, p=1, dim=-1) * (1 - update_rate) + targets * update_rate
                    train_dataset.targets[ids] = new_targets.cpu().numpy()
                # 计算消歧损失
                _, loss_cla = proden_loss(probs, new_targets)
                # 创建 Mixup 样本
                mix_rate = mix_dist.sample((data.shape[0],)).view(-1, 1, 1, 1).to(device)
                ids_mix = torch.randint(low=0, high=data.shape[0], size=(data.shape[0],))
                data_mix = data[ids_mix] * mix_rate + data * (1 - mix_rate)
                # 生成 Mixup 样本的投影
                weak_projection_mix = model(normalize(weak_aug(data_mix)), logit = False, proj=True)
                strong_projection_mix = model(normalize(strong_aug(data_mix)), logit = False, proj=True)
                # 生成原型相似度
                weak_similarity = F.softmax(torch.sum(weak_projection_mix.unsqueeze(1) * weak_prototypes / similarity_temperature, dim=-1), dim=-1)
                strong_similarity = F.softmax(torch.sum(strong_projection_mix.unsqueeze(1) * strong_prototypes / similarity_temperature, dim=-1), dim=-1)
                # 计算对齐损失
                weak_loss_ali = F.kl_div(weak_similarity, new_targets, reduction='batchmean')
                strong_loss_ali = F.kl_div(strong_similarity, new_targets, reduction='batchmean')
                loss_ali = weak_loss_ali + strong_loss_ali
                # 梯度反向传播并更新模型参数
                loss_weight_current = loss_weight * (epoch_id / loss_weight_update_epochs)
                loss_batch = loss_cla + loss_ali * loss_weight_current
                loss_batch.backward()
                optimizer.step()
                # 更新类原型
                model.eval()
                with torch.no_grad():
                    update_rate = prototype_update_rate[0] + (prototype_update_rate[1] - prototype_update_rate[0]) * epoch_id / configs['epoch_count']
                    strong_logits, strong_projection = model(normalize(strong_aug(data)), logit = True, proj=True)
                    weak_pseudo_labels = torch.argmax(weak_logits.detach() * partial_labels, dim=-1)
                    strong_pseudo_labels = torch.argmax(strong_logits.detach() * partial_labels, dim=-1)
                    for sample_id in range(data.shape[0]):
                        strong_label = strong_pseudo_labels[sample_id]
                        new_strong_prototype = strong_prototypes[0, strong_label] * update_rate + strong_projection[sample_id] * (1 - update_rate)
                        strong_prototypes[0, strong_label] = F.normalize(new_strong_prototype, p=2, dim=-1)
                        weak_label = weak_pseudo_labels[sample_id]
                        new_weak_prototype = weak_prototypes[0, weak_label] * update_rate + weak_projection[sample_id] * (1 - update_rate)
                        weak_prototypes[0, weak_label] = F.normalize(new_weak_prototype, p=2, dim=-1)
                # 记录
                loss_sum += loss_batch.item() * data.shape[0]
            loss_mean = loss_sum / len(train_dataset)
            print(f"Loss = {loss_mean}")
            loss_epoch.append(loss_mean)
            # 调整学习率
            self.adjust_learning_rate(optimizer, epoch_id)
            # 测试模型准确率
            model.eval()
            with torch.no_grad():
                shoot_sum = 0
                for batch in tqdm(test_dataloader, desc='Testing(Epoch %d)' % epoch_id, ascii=' 123456789#'):
                    data = normalize(batch['data'].to(device))
                    labels = batch['labels'].to(device)
                    logits = model(data, logit = True, proj=False)
                    shoot = (torch.argmax(logits, dim = -1) == labels).to(torch.float32)
                    shoot_sum += torch.sum(shoot).item()
                acc = shoot_sum / len(test_dataset)
                acc_epoch.append(acc)
                print(f"ACC = {acc * 100}%")

以上代码仅作展示,更详细的代码文件请参见附件。

效果演示

运行命令python main.py --partial_rate 0.01 之后,运行结果如图所示:

本文基于数据集 CIFAR-100 和神经网络 ResNet-18 进行试验。该数据集涉及 100 个类,包含训练样本 50000 条、测试样本 10000 条。

使用方式

  • 解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:
unzip Papi.zip
cd Papi

  • 先安装Anaconda,然后通过如下命令配置代码的运行环境:
conda create -n papi python=3.10.14
conda activate papi
pip install -r requirements.txt

  • 使用如下命令以在不同翻转概率下运行程序,:
python main.py --partial_rate 0.01

python main.py --partial_rate 0.05

python main.py --partial_rate 0.1

python main.py --partial_rate 0.2

  • 如果希望使用其他配置,可以修改由文件 data/configs/papi.json 存储的字典中各键所对应的值。

参考文献

[1] Xia S, Lv J, Xu N, et al. Towards effective visual representations for partial-label learning[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023: 15589-15598.

[2] Wang H, Xiao R, Li Y, et al. Pico: Contrastive label disambiguation for partial label learning[C]//International Conference on Learning Representations. 2021.

​​

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值