【LLIE专题】 CLIP-LIT代码解读

本文是对CLIP-LIT技术的代码解读,原文解读请看CLIP-LIT。

1、原文概要

CLIP 模型凭借其在大规模图像 - 文本数据上学到的知识,在零样本分类等任务中表现出色,具有强大的泛化能力。CLIP-LIT 方法正是受到启发,将 CLIP 模型引入低层次视觉任务,即逆光图像增强。其核心在于设计一个迭代提示学习框架,分为两个主要阶段。第一阶段是提示词初始化与初始增强网络训练;第二阶段是提示词精炼与增强网络迭代微调。其网络结构如下:
在这里插入图片描述

2、代码结构

代码整体结构如下
CLIP-LIT/
├── README.md # 项目说明文档
├── clip_score.py # 定义 CLIP 得分计算函数和损失函数
├── dataloader_images.py # 图像数据加载器
├── dataloader_prompt_add.py # 提示添加数据加载器
├── dataloader_prompt_margin.py # 提示边际数据加载器
├── model_small.py # 定义图像增强模型
├── requirements.txt # 项目依赖文件
├── test.py # 测试脚本
├── test_function.py # 测试相关函数
├── train.py # 训练脚本
├── CLIP/ # CLIP 模型相关代码
│ ├── init.py
│ ├── bpe_simple_vocab_16e6.txt.gz
│ ├── clip.py
│ ├── model.py
│ └── simple_tokenizer.py
├── pretrained_models/ # 预训练模型文件
│ ├── enhancement_model.pth
│ └── init_pretrained_models/
└── input/ # 测试图像输入文件夹
├── 110_51093502029_o.jpg
├── 111_51093504404_o.jpg
├── …

项目的代码结构清晰,主要包括以下几个部分:
数据加载器:dataloader_images.py、dataloader_prompt_add.py 和 dataloader_prompt_margin.py 负责加载训练和测试图像数据,并进行数据增强。
模型定义:model_small.py 定义了图像增强模型,采用了 UNet 结构。
CLIP 得分和损失计算:clip_score.py 实现了 CLIP 得分计算和自定义损失函数,为模型训练提供了必要的指标。
训练脚本:train.py 包含了模型训练的主要逻辑,包括模型加载、数据加载、损失函数定义、优化器设置和训练循环。
测试脚本:test.py 用于对输入图像进行增强测试,加载预训练模型并保存增强后的图像。

3 、核心代码模块

1. CLIP 模块

CLIP/clip.py
  • 功能概述:该文件主要负责 CLIP 模型的下载、加载以及图像的预处理。
  • 核心函数
    • _download(url, root):从指定的 URL 下载模型文件到指定的根目录。
    • _transform(n_px):定义了图像预处理的转换操作,包括调整大小、裁剪、转换为 RGB 格式、转换为张量和归一化。
    • available_models():返回可用的 CLIP 模型名称列表。
    • load(name, device, jit, download_root):加载指定名称的 CLIP 模型,并返回模型和对应的图像预处理函数。
    • tokenize(texts, context_length, truncate):对输入的文本进行分词处理,返回分词后的张量。
# 示例:加载 CLIP 模型
model, preprocess = load("ViT-B/32", device="cuda")
CLIP/simple_tokenizer.py
  • 功能概述:实现了一个简单的分词器,用于将文本转换为分词后的编码。
  • 核心类
    • SimpleTokenizer:分词器类,包含 bpeencodedecode 等方法。
      • bpe(token):对单个词进行字节对编码(BPE)处理。
      • encode(text):对输入的文本进行分词编码。
      • decode(tokens):将分词编码解码为原始文本。
# 示例:使用分词器进行编码
tokenizer = SimpleTokenizer()
tokens = tokenizer.encode("This is a test.")
CLIP/model.py
  • 功能概述:定义了 CLIP 模型的具体结构,包括视觉模型和文本模型。
  • 核心类
    • ModifiedResNet:修改后的 ResNet 模型,用于视觉特征提取。
    • VisionTransformer:基于 Transformer 的视觉模型。
    • CLIP:CLIP 模型的主类,结合了视觉模型和文本模型。
# 示例:创建 CLIP 模型
clip_model = CLIP(
    embed_dim=512,
    image_resolution=224,
    vision_layers=(3, 4, 6, 3),
    vision_width=64,
    vision_patch_size=16,
    context_length=77,
    vocab_size=49408,
    transformer_width=512,
    transformer_heads=8,
    transformer_layers=12
)

2. 训练模块 train.py

  • 功能概述:该文件实现了模型的训练过程,包括数据加载、模型初始化、损失函数定义和梯度更新等。
  • 核心步骤
    1. 加载 CLIP 模型:使用 clip.load 加载预训练的 CLIP 模型,并冻结其参数。
    2. 定义文本编码器和提示模块
      • TextEncoder:用于编码文本提示。
      • Prompts:用于学习文本提示的嵌入表示。
    3. 初始化增强模型:使用 model_small.UNet_emb_oneBranch_symmetry_noreflect 初始化图像增强模型。
    4. 加载数据集:使用 dataloader_sharp.lowlight_loaderdataloader_prompt_margin.lowlight_loader 等加载训练数据。
    5. 定义损失函数:包括 L_clipL_clip_MSEfour_margin_loss 等。
    6. 训练循环:根据不同的训练阶段,交替训练增强模型和提示模块。
# 示例:训练过程中的部分代码
for epoch in range(config.num_epochs):
    if total_iteration < config.num_clip_pretrained_iters:
        # 预训练阶段
        for iteration, item in enumerate(train_loader):
            img_lowlight, img_lowlight_path = item
            img_lowlight = img_lowlight.cuda()
            light_map = U_net(img_lowlight)
            final = torch.clamp(((img_lowlight) / (light_map + 0.000000001)), 0, 1)
            cliploss = 16 * 20 * L_clip(final, text_features)
            clip_MSEloss = 25 * L_clip_MSE(final, img_lowlight, [1.0, 1.0, 1.0, 1.0, 0.5])
            loss = cliploss + clip_MSEloss
            train_optimizer.zero_grad()
            loss.backward()
            train_optimizer.step()

3. 测试模块 test.py

  • 功能概述:该文件实现了模型的测试过程,对输入的低光照图像进行增强处理,并保存增强后的图像。
  • 核心步骤
    1. 解析命令行参数:指定输入图像文件夹、输出图像文件夹和预训练模型路径。
    2. 加载增强模型:使用 model_small.UNet_emb_oneBranch_symmetry 加载预训练的增强模型。
    3. 图像增强处理:对输入的低光照图像进行预处理,通过增强模型生成光照图,然后将原始图像除以光照图得到增强后的图像。
    4. 保存增强后的图像:使用 torchvision.utils.save_image 保存增强后的图像。
# 示例:图像增强处理
def lowlight(image_path):
    data_lowlight = Image.open(image_path)
    data_lowlight = (np.asarray(data_lowlight) / 255.0)
    data_lowlight = torch.from_numpy(data_lowlight).float().cuda()
    data_lowlight = data_lowlight.permute(2, 0, 1)
    data_lowlight = data_lowlight.unsqueeze(0)
    light_map = U_net(data_lowlight)
    enhanced_image = torch.clamp((data_lowlight / light_map), 0, 1)
    result_path = args.output + os.path.basename(image_path).replace('.jpg', '.png')
    torchvision.utils.save_image(enhanced_image, result_path)

4. 损失函数模块 clip_score.py

  • 功能概述:该文件定义了与 CLIP 相关的损失函数,用于模型的训练。
  • 核心类和函数
    • get_clip_score(tensor, words):计算输入图像张量与指定文本的 CLIP 得分。
    • L_clip:CLIP 损失类,用于计算图像与文本的相似度损失。
    • Prompts:提示模块类,用于学习文本提示的嵌入表示。
    • L_clip_from_feature:根据文本特征计算 CLIP 损失。
    • L_clip_MSE:计算图像特征的均方误差损失。
    • four_margin_loss:四边缘损失类,用于处理多组图像之间的相似度关系。
# 示例:使用 L_clip 损失函数
l_clip = L_clip()
loss = l_clip(x, light=True)

5. 网络结构模块 model_small.py

model_small.py 是 CLIP-LIT 项目中定义轻量级图像增强模型的核心文件,基于 U-Net 架构设计,用于从逆光图像中预测光照图,从而实现图像增强。

主模型类 UNet_emb_oneBranch_symmetry
功能概述

基于 U-Net 的对称编码器-解码器架构,输入逆光图像,输出单通道光照图(范围为 (0, 1]),通过原图除以光照图实现增强。

网络结构
输入图像 (3通道)
├── 编码路径(下采样):
│   ├── 层1: DoubleConv(3→64)
│   ├── 层2: Down(64→128)
│   ├── 层3: Down(128→256)
│   └── 层4: Down(256→512)
├── 瓶颈层: DoubleConv(512→1024)
└── 解码路径(上采样):
    ├── 层5: Up(1024→512) + 层3特征拼接
    ├── 层6: Up(512→256) + 层2特征拼接
    ├── 层7: Up(256→128) + 层1特征拼接
    └── 输出层: Conv2d(128→1)(Sigmoid激活)
核心代码
class UNet_emb_oneBranch_symmetry(nn.Module):
    def __init__(self, n_channels=3, bilinear=True):
        super(UNet_emb_oneBranch_symmetry, self).__init__()
        self.n_channels = n_channels
        self.bilinear = bilinear
        
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.bottleneck = DoubleConv(512, 1024)
        self.up1 = Up(1024, 512, bilinear)
        self.up2 = Up(512, 256, bilinear)
        self.up3 = Up(256, 128, bilinear)
        self.outc = nn.Conv2d(128, 1, kernel_size=1)
    
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.bottleneck(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        logits = self.outc(x)
        # 光照图通过 Sigmoid 限制在 (0, 1]
        return torch.sigmoid(logits) + 1e-8  # 加极小值避免除零
关键细节
  1. 输入输出

    • 输入:形状为 (B, 3, H, W) 的逆光图像(RGB,归一化到 [0, 1])。
    • 输出:形状为 (B, 1, H, W) 的光照图,数值范围接近 (0, 1]。
  2. 特征融合

    • 解码路径中,上采样后的特征图与编码路径对应层级的特征图通过 torch.cat 拼接,融合浅层细节与深层语义。
  3. 激活函数

    • 输出层使用 Sigmoid 激活,确保光照图为正数,避免原图除以光照图时出现数值不稳定。

4、详细代码注释

下面是相关脚本作用以及代码详细注释
train.py:训练脚本,包含模型初始化、数据集加载、损失函数定义和优化器设置等步骤。通过迭代训练,不断更新模型和提示参数。代码每一行都加了注释,代码如下。

from math import sqrt
import os
# os.environ['CUDA_VISIBLE_DEVICES']='0,1'  # 可用于指定使用的GPU设备,这里注释掉了
import torch
import torch.nn as nn
from torch.nn import functional as F
# import torchvision
import torch.optim
import argparse

import dataloader_prompt_margin
import dataloader_prompt_add
import dataloader_images as dataloader_sharp 

import model_small

import numpy as np

from test_function import inference

import clip_score
import random
from collections import OrderedDict
from torch.utils.tensorboard import SummaryWriter
import clip

import pyiqa
import shutil

# 定义训练任务的名称
task_name="train0"
# 创建一个SummaryWriter对象,用于将训练过程中的数据写入TensorBoard,方便可视化
writer = SummaryWriter('./'+task_name+"/"+'tensorboard_'+task_name)

# 定义保存训练脚本的目标路径
dstpath="./"+task_name+"/"+"train_scripts"
# 如果目标路径不存在,则创建该路径
if not os.path.exists(dstpath):
    os.makedirs(dstpath)
# 将当前的train.py脚本复制到目标路径下
shutil.copy("train.py",dstpath)

# 检查是否有可用的GPU,如果有则使用GPU,否则使用CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# 打印当前使用的设备
print(device)
# 加载CLIP模型,使用ViT-B/32架构,将模型下载到指定的路径
model, preprocess = clip.load("ViT-B/32", device=torch.device("cpu"), download_root="./clip_model/")#ViT-B/32
# 将CLIP模型移动到指定设备上
model.to(device)
# 冻结CLIP模型的所有参数,不进行训练
for para in model.parameters():
    para.requires_grad = False

# 定义一个文本编码器类,用于对文本进行编码
class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        # 从CLIP模型中获取transformer层
        self.transformer = clip_model.transformer
        # 从CLIP模型中获取位置嵌入层
        self.positional_embedding = clip_model.positional_embedding
        # 从CLIP模型中获取最终的归一化层
        self.ln_final = clip_model.ln_final
        # 从CLIP模型中获取文本投影层
        self.text_projection = clip_model.text_projection
        # 获取CLIP模型的数据类型
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        # 将提示词和位置嵌入相加
        x = prompts + self.positional_embedding.type(self.dtype)
        # 调整张量的维度顺序
        x = x.permute(1, 0, 2)  # NLD -> LND
        # 通过transformer层进行特征提取
        x = self.transformer(x)
        # 调整张量的维度顺序
        x = x.permute(1, 0, 2)  # LND -> NLD
        # 通过最终的归一化层
        x = self.ln_final(x).type(self.dtype)
        # 根据tokenized_prompts的最大值位置选择对应的特征,并进行投影
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
        
        return x

# 定义一个提示词类,用于生成和处理提示词
class Prompts(nn.Module):
    def __init__(self,initials=None):
        super(Prompts,self).__init__()
        # 打印初始提示词
        print("The initial prompts are:",initials)
        # 初始化文本编码器
        self.text_encoder = TextEncoder(model)
        if isinstance(initials,list):
            # 如果初始提示词是列表,则将其进行分词并转换为嵌入向量
            text = clip.tokenize(initials).cuda()
            # print(text)
            self.embedding_prompt = nn.Parameter(model.token_embedding(text).requires_grad_()).cuda()
        elif isinstance(initials,str):
            # 如果初始提示词是字符串,则认为是提示词的保存路径,加载提示词
            prompt_path=initials
            state_dict = torch.load(prompt_path)
            # 创建一个新的有序字典,去除模块名前缀
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:] # remove `module.`
                new_state_dict[name] = v
            # 将加载的提示词嵌入向量作为可训练参数
            self.embedding_prompt=nn.Parameter(new_state_dict['embedding_prompt']).cuda()
            self.embedding_prompt.requires_grad = True
        else:
            # 如果没有提供初始提示词,则随机初始化提示词嵌入向量
            self.embedding_prompt=torch.nn.init.xavier_normal_(nn.Parameter(model.token_embedding([" ".join(["X"]*config.length_prompt)," ".join(["X"]*config.length_prompt)]).requires_grad_())).cuda()

    def forward(self,tensor,flag=1):
        # 对提示词进行分词
        tokenized_prompts= torch.cat([clip.tokenize(p) for p in [" ".join(["X"]*config.length_prompt)]])
        # 通过文本编码器获取文本特征
        text_features = self.text_encoder(self.embedding_prompt,tokenized_prompts)
        
        for i in range(tensor.shape[0]):
            # 获取图像特征
            image_features=tensor[i]
            # 对文本特征进行归一化
            nor=torch.norm(text_features,dim=-1, keepdim=True)
            if flag==0:
                # 计算图像特征和文本特征的相似度,不进行softmax操作
                similarity = (100.0 * image_features @ (text_features/nor).T)#.softmax(dim=-1)
                if(i==0):
                    probs=similarity
                else:
                    probs=torch.cat([probs,similarity],dim=0)
            else:
                # 计算图像特征和文本特征的相似度,并进行softmax操作
                similarity = (100.0 * image_features @ (text_features/nor).T).softmax(dim=-1)#/nor
                if(i==0):
                    probs=similarity[:,0]
                else:
                    probs=torch.cat([probs,similarity[:,0]],dim=0)
        return probs

# 定义一个权重初始化函数,用于对卷积层和批量归一化层的权重进行初始化
def weights_init(m):
    classname = m.__class__.__name__ 
    if classname.find('Conv') != -1:
        # 对卷积层的权重进行正态分布初始化
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # 对批量归一化层的权重进行正态分布初始化
        m.weight.data.normal_(1.0, 0.02)
        # 将批量归一化层的偏置置为0
        m.bias.data.fill_(0)

# 定义一个随机裁剪函数,用于对图像进行随机裁剪
def random_crop(img):
    b,c,h,w=img.shape
    # 随机选择裁剪的起始高度
    hs=random.randint(0,h-224)
    # 随机选择裁剪的起始宽度
    hw=random.randint(0,w-224)
    return img[:,:,hs:hs+224,hw:hw+224]

# 定义训练函数
def train(config):
    
    # 加载图像增强模型
    U_net=model_small.UNet_emb_oneBranch_symmetry_noreflect(3,1).cuda()
  
    # 初始化图像质量评估指标(PSNR)
    iqa_metric = pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr').to(device)
    
    # 如果需要加载预训练的提示词
    if config.load_pretrain_prompt == True:
        # 初始化提示词模型,并加载预训练的提示词
        learn_prompt=Prompts(config.prompt_pretrain_dir).cuda()
        # 保存预训练的提示词
        torch.save(learn_prompt.state_dict(), config.prompt_snapshots_folder + "pretrained_prompt" + '.pth')
    else:
        if config.num_clip_pretrained_iters < 3000:
            # 如果从无预训练开始训练,提示需要足够的迭代次数
            print("WARNING: For training from scratch, num_clip_pretrained_iters should not lower than 3000 iterations!\nAutomatically reset num_clip_pretrained_iters to 8000 iterations...")
            config.num_clip_pretrained_iters=8000
        # 初始化提示词模型
        learn_prompt=Prompts([" ".join(["X"]*(config.length_prompt))," ".join(["X"]*(config.length_prompt))]).cuda()
    # 使用DataParallel对提示词模型进行并行训练
    learn_prompt =  torch.nn.DataParallel(learn_prompt)
    # 对图像增强模型的权重进行初始化
    U_net.apply(weights_init)
    
    # 如果需要加载预训练的图像增强模型
    if config.load_pretrain == True:
        print("The load_pretrain is True, thus num_reconstruction_iters is automatically set to 0.")
        config.num_reconstruction_iters=0
        state_dict = torch.load(config.pretrain_dir)
        # 创建一个新的有序字典,去除模块名前缀
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] # remove `module.`
            new_state_dict[name] = v
        # 加载预训练的图像增强模型权重
        U_net.load_state_dict(new_state_dict)
        #U_net.load_state_dict(torch.load(config.pretrain_dir))
        # 保存预训练的图像增强模型
        torch.save(U_net.state_dict(), config.train_snapshots_folder + "pretrained_network" + '.pth')
    else:
        if config.num_reconstruction_iters<200:
            # 如果从无预训练开始训练,提示需要足够的迭代次数
            print("WARNING: For training from scratch, num_reconstruction_iters should not lower than 200 iterations!\nAutomatically reset num_reconstruction_iters to 1000 iterations...")
            config.num_reconstruction_iters=1000
    # 使用DataParallel对图像增强模型进行并行训练
    U_net= torch.nn.DataParallel(U_net)
    
    # 加载训练数据集
    train_dataset = dataloader_sharp.lowlight_loader(config.lowlight_images_path,config.overlight_images_path)    #dataloader
    # 创建训练数据加载器
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True)
    
    # 加载提示词训练数据集
    prompt_train_dataset = dataloader_prompt_margin.lowlight_loader(config.lowlight_images_path,config.normallight_images_path)#,config.overlight_images_path)        
    # 创建提示词训练数据加载器
    prompt_train_loader = torch.utils.data.DataLoader(prompt_train_dataset, batch_size=config.prompt_batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True)
    # 加载另一个提示词训练数据集
    prompt_train_dataset_1 = dataloader_prompt_add.lowlight_loader(config.lowlight_images_path,config.normallight_images_path)
    # 创建另一个提示词训练数据加载器
    prompt_train_loader_1 = torch.utils.data.DataLoader(prompt_train_dataset_1, batch_size=config.prompt_batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True)
    
    # 初始化文本编码器
    text_encoder = TextEncoder(model)
    # 初始化CLIP损失函数
    L_clip = clip_score.L_clip_from_feature()
    # 初始化CLIP MSE损失函数
    L_clip_MSE = clip_score.L_clip_MSE()
    # 初始化边际损失函数
    L_margin_loss = clip_score.four_margin_loss(0.9,0.2)#0.9,0.2
    
    # 定义图像增强模型的优化器
    train_optimizer = torch.optim.Adam(U_net.parameters(), lr=config.train_lr, weight_decay=config.weight_decay)
    # reconsturction_train_optimizer = torch.optim.Adam(U_net.parameters(), lr=config.reconstruction_train_lr, weight_decay=config.weight_decay)
    # 定义提示词模型的优化器
    prompt_optimizer = torch.optim.Adam(learn_prompt.parameters(), lr=config.prompt_lr, weight_decay=config.weight_decay)

    # 初始化训练参数
    U_net.train()
    total_iteration=0
    cur_iteration=0
    max_score_psnr=-10000
    pr_last_few_iter=0
    score_psnr=[0]*30
    semi_path=['','']
    pr_semi_path=0
    #last_iteration=0
    best_model=U_net
    best_prompt=learn_prompt
    min_prompt_loss=100
    best_prompt_iter=0
    best_model_iter=0
    rounds=0
    reconstruction_iter=0
    reinit_flag=0
    
    # 开始训练循环
    for epoch in range(config.num_epochs):
        if total_iteration<config.num_clip_pretrained_iters:
            train_thre=0
            total_thre=config.num_clip_pretrained_iters
        elif total_iteration<config.num_reconstruction_iters+config.num_clip_pretrained_iters:
            train_thre=config.num_reconstruction_iters
            total_thre=config.num_reconstruction_iters
        elif cur_iteration==0:
            train_thre=2100#800#2100#800#200
            total_thre=3100#2800#3100#1200#500
            print("cur using prompt from: iteration ", best_prompt_iter)
            print("cur using best model from: iteration ", best_model_iter)
        if cur_iteration+1<=train_thre: 
            if cur_iteration==0:
                learn_prompt=best_prompt
            # 获取提示词嵌入向量
            embedding_prompt=learn_prompt.module.embedding_prompt
            # 冻结提示词嵌入向量的梯度
            embedding_prompt.requires_grad = False
            # 对提示词进行分词
            tokenized_prompts= torch.cat([clip.tokenize(p) for p in [" ".join(["X"]*config.length_prompt)]])
            # 通过文本编码器获取文本特征
            text_features = text_encoder(embedding_prompt,tokenized_prompts)
            # 冻结提示词模型的所有参数
            for name, param in learn_prompt.named_parameters():
                param.requires_grad_(False)

            for iteration, item in enumerate(train_loader): 
        
                img_lowlight ,img_lowlight_path=item
                
                img_lowlight = img_lowlight.cuda()

                # 通过图像增强模型生成光照图
                light_map  = U_net(img_lowlight)
                # 根据光照图对低光照图像进行增强处理
                final=torch.clamp(((img_lowlight) /(light_map+0.000000001)),0,1)
               
                # 计算CLIP损失
                cliploss=16*20*L_clip(final, text_features)
                # 计算CLIP MSE损失
                clip_MSEloss = 25*L_clip_MSE(final, img_lowlight,[1.0,1.0,1.0,1.0,0.5])

model_small.py:文件主要定义了几个用于图像光照增强的神经网络模型,同时包含了一些辅助模块,如金字塔池化模块和残差块。加了注释代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义一个图像光照增强模型,采用单分支对称结构且不使用反射填充
class UNet_emb_oneBranch_symmetry_noreflect(nn.Module):

    # 初始化函数,in_channels为输入通道数,默认为3;out_channels为输出通道数,默认为3;bias为是否使用偏置,默认为False
    def __init__(self, in_channels=3, out_channels=3,bias=False):
        # 调用父类的初始化函数
        super(UNet_emb_oneBranch_symmetry_noreflect, self).__init__()

        # 第一个卷积层,将输入通道数转换为32通道
        self.cond1 = nn.Conv2d(in_channels,32,3,1,1,bias=True) 
        # 最后一个卷积层,将32通道转换为输出通道数
        self.cond_add1 = nn.Conv2d(32,out_channels,3,1,1,bias=True)           

        # 中间的卷积层,用于特征维度转换
        self.condx = nn.Conv2d(32,64,3,1,1,bias=True) 
        self.condy = nn.Conv2d(64,32,3,1,1,bias=True) 

        # 定义ReLU激活函数,inplace为True表示直接在原张量上进行操作
        self.relu = nn.ReLU(inplace=True)
        # 定义LeakyReLU激活函数,负斜率为0.2
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # 定义多个残差块,用于特征提取和保留
        self.ResidualBlock1=ResidualBlock(32,32)
        self.ResidualBlock2=ResidualBlock(32,32)
        self.ResidualBlock3=ResidualBlock(64,64)
        self.ResidualBlock4=ResidualBlock(64,64)
        self.ResidualBlock5=ResidualBlock(32,32)
        self.ResidualBlock6=ResidualBlock(32,32)

        # 定义金字塔池化模块,用于多尺度特征融合
        self.PPM1 = PPM1(32,8,bins=(1,2,3,6))

    # 权重初始化函数,对卷积层和反卷积层的权重进行初始化
    def _init_weights(self):
        # 遍历模型的所有模块
        for m in self.modules():
            # 如果是反卷积层或卷积层
            if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
                # 使用正态分布初始化权重,均值为0,标准差为0.02
                m.weight.data.normal_(0.0, 0.02)
                # 可以选择将偏置初始化为0
                #nn.init.zeros_(m.bias.data)

    # 前向传播函数
    def forward(self, x):
        # 通过第一个卷积层并使用LeakyReLU激活
        light_conv1=self.lrelu(self.cond1(x))
        # 通过第一个残差块
        res1=self.ResidualBlock1(light_conv1)
        
        # 通过第二个残差块
        res2=self.ResidualBlock2(res1)
        # 通过金字塔池化模块
        res2=self.PPM1(res2)
        # 通过中间卷积层进行特征维度转换
        res2=self.condx(res2)
        
        # 通过第三个残差块
        res3=self.ResidualBlock3(res2)
        # 通过第四个残差块
        res4=self.ResidualBlock4(res3)

        # 通过中间卷积层进行特征维度转换
        res4=self.condy(res4)
        # 通过第五个残差块
        res5=self.ResidualBlock5(res4)
        
        # 通过第六个残差块
        res6=self.ResidualBlock6(res5)
        
        # 通过最后一个卷积层并使用ReLU激活,得到光照图
        light_map=self.relu(self.cond_add1(res6))
 
        return light_map

# 定义一个图像光照增强模型,采用单分支对称结构且使用反射填充
class UNet_emb_oneBranch_symmetry(nn.Module):
    
    # 初始化函数,in_channels为输入通道数,默认为3;out_channels为输出通道数,默认为3;bias为是否使用偏置,默认为False
    def __init__(self, in_channels=3, out_channels=3,bias=False):
        # 调用父类的初始化函数
        super(UNet_emb_oneBranch_symmetry, self).__init__()

        # 第一个卷积层,将输入通道数转换为32通道,使用反射填充
        self.cond1 = nn.Conv2d(in_channels,32,3,1,1,bias=True,padding_mode='reflect') 
        # 最后一个卷积层,将32通道转换为输出通道数,使用反射填充
        self.cond_add1 = nn.Conv2d(32,out_channels,3,1,1,bias=True,padding_mode='reflect')           

        # 中间的卷积层,用于特征维度转换,使用反射填充
        self.condx = nn.Conv2d(32,64,3,1,1,bias=True,padding_mode='reflect') 
        self.condy = nn.Conv2d(64,32,3,1,1,bias=True,padding_mode='reflect') 

        # 定义ReLU激活函数,inplace为True表示直接在原张量上进行操作
        self.relu = nn.ReLU(inplace=True)
        # 定义LeakyReLU激活函数,负斜率为0.2
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # 定义多个残差块,用于特征提取和保留
        self.ResidualBlock1=ResidualBlock(32,32)
        self.ResidualBlock2=ResidualBlock(32,32)
        self.ResidualBlock3=ResidualBlock(64,64)
        self.ResidualBlock4=ResidualBlock(64,64)
        self.ResidualBlock5=ResidualBlock(32,32)
        self.ResidualBlock6=ResidualBlock(32,32)

        # 定义金字塔池化模块,用于多尺度特征融合
        self.PPM1 = PPM1(32,8,bins=(1,2,3,6))

    # 权重初始化函数,对卷积层和反卷积层的权重进行初始化
    def _init_weights(self):
        # 遍历模型的所有模块
        for m in self.modules():
            # 如果是反卷积层或卷积层
            if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
                # 使用正态分布初始化权重,均值为0,标准差为0.02
                m.weight.data.normal_(0.0, 0.02)
                # 可以选择将偏置初始化为0
                #nn.init.zeros_(m.bias.data)

    # 前向传播函数
    def forward(self, x):
        # 通过第一个卷积层并使用LeakyReLU激活
        light_conv1=self.lrelu(self.cond1(x))
        # 通过第一个残差块
        res1=self.ResidualBlock1(light_conv1)
        
        # 通过第二个残差块
        res2=self.ResidualBlock2(res1)
        # 通过金字塔池化模块
        res2=self.PPM1(res2)
        # 通过中间卷积层进行特征维度转换
        res2=self.condx(res2)
        
        # 通过第三个残差块
        res3=self.ResidualBlock3(res2)
        # 通过第四个残差块
        res4=self.ResidualBlock4(res3)
        # 通过中间卷积层进行特征维度转换
        res4=self.condy(res4)
        
        # 通过第五个残差块
        res5=self.ResidualBlock5(res4)
        # 通过第六个残差块
        res6=self.ResidualBlock6(res5)

        # 通过最后一个卷积层并使用ReLU激活,得到光照图
        light_map=self.relu(self.cond_add1(res6))

        return light_map

# 定义金字塔池化模块
class PPM1(nn.Module):
    # 初始化函数,in_dim为输入维度,reduction_dim为降维后的维度,bins为池化的尺度
    def __init__(self, in_dim, reduction_dim, bins):
        # 调用父类的初始化函数
        super(PPM1, self).__init__()
        # 定义特征提取模块列表
        self.features = []
        # 遍历不同的池化尺度
        for bin in bins:
            # 每个尺度下的特征提取模块,包括自适应平均池化、卷积层和PReLU激活函数
            self.features.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(bin),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.PReLU()
            ))
        # 将特征提取模块列表转换为ModuleList
        self.features = nn.ModuleList(self.features)
        # 定义特征融合模块,包括卷积层和PReLU激活函数
        self.fuse = nn.Sequential(
                nn.Conv2d(in_dim+reduction_dim*4, in_dim, kernel_size=3, padding=1, bias=False),
                nn.PReLU())

    # 前向传播函数
    def forward(self, x):
        # 获取输入张量的尺寸
        x_size = x.size()
        # 初始化输出列表,第一个元素为输入张量
        out = [x]
        # 遍历特征提取模块
        for f in self.features:
            # 对输入张量进行特征提取,并进行双线性插值恢复到原尺寸
            out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
        # 将所有特征拼接起来,并通过特征融合模块
        out_feat = self.fuse(torch.cat(out, 1))
        return out_feat       

# 定义残差块
class ResidualBlock(nn.Module):
    # 初始化函数,in_channels为输入通道数,out_channels为输出通道数,stride为步长,downsample为下采样模块
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        # 调用父类的初始化函数
        super(ResidualBlock, self).__init__()
        # 第一个3x3卷积层
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        # 定义ReLU激活函数
        self.relu = nn.ReLU(inplace=True)
        # 第二个3x3卷积层
        self.conv2 = conv3x3(out_channels, out_channels)
        # 下采样模块
        self.downsample = downsample
        # 定义LeakyReLU激活函数,负斜率为0.2
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    # 前向传播函数
    def forward(self, x):
        # 保存输入张量作为残差
        residual = x
        # 通过第一个卷积层
        out = self.conv1(x)
        # 通过LeakyReLU激活函数
        out = self.lrelu(out)
        # 通过第二个卷积层
        out = self.conv2(out)
        # 如果有下采样模块,则对输入张量进行下采样
        if self.downsample:
            residual = self.downsample(x)
        # 将残差与输出相加
        out += residual
        # 通过LeakyReLU激活函数
        out = self.lrelu(out)
        return out

# 定义3x3卷积层
def conv3x3(in_channels, out_channels, stride=1):
    # 创建一个3x3卷积层,使用反射填充,不使用偏置
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                     stride=stride, padding=1, bias=False,padding_mode='reflect')

CLIP/model.py:定义了CLIP类,这是 CLIP 模型的核心实现。包含视觉和文本两个部分,通过初始化参数构建模型结构,并实现了图像和文本特征的编码以及前向传播方法。

from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

# 定义一个Bottleneck模块,继承自nn.Module,用于构建ResNet中的残差块
class Bottleneck(nn.Module):
    # 定义扩张系数,用于调整输出通道数
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        # 调用父类的构造函数
        super().__init__()

        # 第一个卷积层,使用1x1卷积进行通道数的缩减
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        # 第一个批量归一化层
        self.bn1 = nn.BatchNorm2d(planes)
        # 第一个ReLU激活函数
        self.relu1 = nn.ReLU(inplace=True)

        # 第二个卷积层,使用3x3卷积进行特征提取
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        # 第二个批量归一化层
        self.bn2 = nn.BatchNorm2d(planes)
        # 第二个ReLU激活函数
        self.relu2 = nn.ReLU(inplace=True)

        # 当步长大于1时,使用平均池化进行下采样,否则使用恒等映射
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        # 第三个卷积层,使用1x1卷积进行通道数的扩张
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        # 第三个批量归一化层
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        # 第三个ReLU激活函数
        self.relu3 = nn.ReLU(inplace=True)

        # 下采样模块,用于处理输入通道数和输出通道数不匹配的情况
        self.downsample = None
        # 保存步长
        self.stride = stride

        # 当步长大于1或者输入通道数不等于输出通道数时,定义下采样模块
        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # 下采样模块包含平均池化和1x1卷积以及批量归一化
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        # 保存输入作为恒等映射
        identity = x

        # 经过第一个卷积层、批量归一化层和ReLU激活函数
        out = self.relu1(self.bn1(self.conv1(x)))
        # 经过第二个卷积层、批量归一化层和ReLU激活函数
        out = self.relu2(self.bn2(self.conv2(out)))
        # 进行平均池化(如果需要)
        out = self.avgpool(out)
        # 经过第三个卷积层和批量归一化层
        out = self.bn3(self.conv3(out))

        # 如果存在下采样模块,对输入进行下采样
        if self.downsample is not None:
            identity = self.downsample(x)

        # 将输出和恒等映射相加
        out += identity
        # 经过第三个ReLU激活函数
        out = self.relu3(out)
        return out

# 定义一个注意力池化模块,用于对特征图进行池化操作
class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        # 调用父类的构造函数
        super().__init__()
        # 定义位置嵌入,用于为每个位置的特征添加位置信息
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        # 定义键投影层
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        # 定义查询投影层
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        # 定义值投影层
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # 定义输出投影层,如果没有指定输出维度,则使用嵌入维度
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        # 保存注意力头的数量
        self.num_heads = num_heads

    def forward(self, x):
        # 将输入的特征图从NCHW格式转换为(HW)NC格式
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)
        # 在特征图的开头添加全局平均池化的结果
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)
        # 添加位置嵌入
        x = x + self.positional_embedding[:, None, :].to(x.dtype)
        # 使用多头注意力机制进行特征提取
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        return x[0]

# 定义一个修改后的ResNet模型,继承自nn.Module
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        # 调用父类的构造函数
        super().__init__()
        # 保存输出维度
        self.output_dim = output_dim
        # 保存输入分辨率
        self.input_resolution = input_resolution

        # 定义3层的stem卷积层
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # 定义残差层的输入通道数
        self._inplanes = width
        # 构建第一层残差层
        self.layer1 = self._make_layer(width, layers[0])
        # 构建第二层残差层,步长为2
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        # 构建第三层残差层,步长为2
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        # 构建第四层残差层,步长为2
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        # 计算ResNet的特征维度
        embed_dim = width * 32
        # 定义注意力池化层
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        # 构建第一个残差块
        layers = [Bottleneck(self._inplanes, planes, stride)]

        # 更新输入通道数
        self._inplanes = planes * Bottleneck.expansion
        # 构建剩余的残差块
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        # 定义stem函数,用于处理输入图像
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        # 将输入转换为与第一个卷积层权重相同的数据类型
        x = x.type(self.conv1.weight.dtype)
        # 经过stem函数处理
        x = stem(x)
        # 经过第一层残差层
        x1 = self.layer1(x)
        # 经过第二层残差层
        x2 = self.layer2(x1)
        # 经过第三层残差层
        x3 = self.layer3(x2)
        # 经过第四层残差层
        x4 = self.layer4(x3)
        # 经过注意力池化层
        y = self.attnpool(x4)

        return y, [x, x1, x2, x3, x4]

# 定义一个LayerNorm模块,继承自torch的LayerNorm,用于处理fp16数据类型
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        # 保存输入的数据类型
        orig_type = x.dtype
        # 将输入转换为float32类型进行计算
        ret = super().forward(x.type(torch.float32))
        # 将输出转换回原始的数据类型
        return ret.type(orig_type)

# 定义一个快速GELU激活函数模块
class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        # 实现快速GELU激活函数
        return x * torch.sigmoid(1.702 * x)

# 定义一个残差注意力块模块,继承自nn.Module
class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        # 调用父类的构造函数
        super().__init__()

        # 定义多头注意力层
        self.attn = nn.MultiheadAttention(d_model, n_head)
        # 定义第一个LayerNorm层
        self.ln_1 = LayerNorm(d_model)
        # 定义多层感知机(MLP)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        # 定义第二个LayerNorm层
        self.ln_2 = LayerNorm(d_model)
        # 保存注意力掩码
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # 如果存在注意力掩码,将其转换为与输入相同的数据类型和设备
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        # 进行多头注意力计算
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        # 经过第一个LayerNorm层和多头注意力层,并与输入相加
        x = x + self.attention(self.ln_1(x))
        # 经过第二个LayerNorm层和MLP,并与输入相加
        x = x + self.mlp(self.ln_2(x))
        return x

# 定义一个Transformer模块,继承自nn.Module
class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        # 调用父类的构造函数
        super().__init__()
        # 保存Transformer的宽度
        self.width = width
        # 保存Transformer的层数
        self.layers = layers
        # 构建残差注意力块序列
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        # 经过残差注意力块序列
        return self.resblocks(x)

# 定义一个视觉Transformer模块,继承自nn.Module
class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        # 调用父类的构造函数
        super().__init__()
        # 保存输入分辨率
        self.input_resolution = input_resolution
        # 保存输出维度
        self.output_dim = output_dim
        # 定义卷积层,用于将输入图像分割成多个patch
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        # 定义缩放因子
        scale = width ** -0.5
        # 定义类别嵌入
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        # 定义位置嵌入
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        # 定义预LayerNorm层
        self.ln_pre = LayerNorm(width)

        # 定义Transformer模块
        self.transformer = Transformer(width, layers, heads)

        # 定义后LayerNorm层
        self.ln_post = LayerNorm(width)
        # 定义投影层
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        # 经过卷积层,将输入图像分割成多个patch
        x = self.conv1(x)
        # 将patch的特征展平
        x = x.reshape(x.shape[0], x.shape[1], -1)
        # 调整维度顺序
        x = x.permute(0, 2, 1)
        # 添加类别嵌入
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
        # 添加位置嵌入
        x = x + self.positional_embedding.to(x.dtype)
        # 经过预LayerNorm层
        x = self.ln_pre(x)

        # 调整维度顺序,以适应Transformer的输入要求
        x = x.permute(1, 0, 2)
        # 经过Transformer模块
        x = self.transformer(x)
        # 调整维度顺序,以适应后续处理
        x = x.permute(1, 0, 2)

        # 取第一个位置的特征
        x = self.ln_post(x[:, 0, :])

        # 如果存在投影层,进行投影操作
        if self.proj is not None:
            x = x @ self.proj

        return x

# 定义一个CLIP模型,继承自nn.Module
class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        # 调用父类的构造函数
        super().__init__()

        # 保存上下文长度
        self.context_length = context_length

        # 如果视觉层是一个元组或列表
        if isinstance(vision_layers, (tuple, list)):
            # 计算视觉注意力头的数量
            vision_heads = vision_width * 32 // 64
            # 定义视觉模型,使用修改后的ResNet
            self.visual = ModifiedResNet(
                # 此处代码未完整,应继续完成视觉模型的初始化

CLIP/clip.py:提供了加载预训练 CLIP 模型和文本分词的功能。包括下载预训练模型、图像预处理和文本分词等方法。

import hashlib  # 导入哈希计算库,用于计算文件的哈希值,以验证文件完整性
import os  # 导入操作系统相关功能库,用于文件和目录操作
import urllib  # 导入URL处理库,用于从网络下载文件
import warnings  # 导入警告处理库,用于发出和处理警告信息
from typing import Any, Union, List  # 导入类型提示相关库,增强代码可读性和可维护性
from pkg_resources import packaging  # 导入版本号处理库,用于比较PyTorch版本

import torch  # 导入PyTorch深度学习库
from PIL import Image  # 导入Pillow库,用于图像处理
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize  # 导入图像转换操作库
from tqdm import tqdm  # 导入进度条库,用于显示下载进度

from .model import build_model  # 从当前目录下的model.py文件中导入build_model函数
from .simple_tokenizer import SimpleTokenizer as _Tokenizer  # 从当前目录下的simple_tokenizer.py文件中导入SimpleTokenizer类,并将其重命名为_Tokenizer

try:
    from torchvision.transforms import InterpolationMode  # 尝试从torchvision.transforms中导入InterpolationMode类
    BICUBIC = InterpolationMode.BICUBIC  # 设置双三次插值模式
except ImportError:
    BICUBIC = Image.BICUBIC  # 如果导入失败,使用Pillow库的双三次插值模式

# 检查PyTorch版本是否低于1.7.1,如果是则发出警告
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")

# 定义模块导出的函数和类
__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()  # 创建一个SimpleTokenizer对象

# 定义可用的CLIP模型及其对应的下载URL
_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}

# 定义下载函数,用于从指定URL下载模型文件到指定根目录
def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)  # 创建根目录,如果目录已存在则不报错
    filename = os.path.basename(url)  # 从URL中提取文件名

    expected_sha256 = url.split("/")[-2]  # 从URL中提取预期的SHA256哈希值
    download_target = os.path.join(root, filename)  # 拼接下载目标文件的完整路径

    # 检查下载目标文件是否存在且不是普通文件,如果是则抛出运行时错误
    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    # 检查下载目标文件是否存在,如果存在则验证其SHA256哈希值
    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target  # 如果哈希值匹配,则返回下载目标文件的路径
        else:
            # 如果哈希值不匹配,则发出警告并重新下载文件
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    # 打开URL并将文件内容写入下载目标文件,同时显示下载进度
    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)  # 每次读取8192字节
                if not buffer:
                    break  # 如果读取到文件末尾,则退出循环

                output.write(buffer)  # 将读取的内容写入文件
                loop.update(len(buffer))  # 更新进度条

    # 再次验证下载文件的SHA256哈希值,如果不匹配则抛出运行时错误
    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")

    return download_target  # 返回下载目标文件的路径

# 定义将图像转换为RGB格式的函数
def _convert_image_to_rgb(image):
    return image.convert("RGB")

# 定义图像预处理转换函数,用于将图像调整为指定大小并进行归一化处理
def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),  # 将图像调整为指定大小,使用双三次插值
        CenterCrop(n_px),  # 对图像进行中心裁剪
        _convert_image_to_rgb,  # 将图像转换为RGB格式
        ToTensor(),  # 将图像转换为张量
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # 对图像进行归一化处理
    ])

# 定义返回可用CLIP模型名称列表的函数
def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())

# 定义加载CLIP模型的函数
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:  # 如果指定的模型名称在可用模型列表中
        # 调用_download函数下载模型文件
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):  # 如果指定的名称是一个文件路径
        model_path = name  # 直接使用该文件路径作为模型路径
    else:
        # 如果指定的模型名称不存在且不是文件路径,则抛出运行时错误
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    # 打开模型文件
    with open(model_path, 'rb') as opened_file:
        try:
            # 尝试以JIT存档的方式加载模型
            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            state_dict = None  # 初始化状态字典为None
        except RuntimeError:
            # 如果加载失败,则以保存的状态字典方式加载模型
            if jit:
                # 如果尝试加载JIT模型失败,则发出警告并将jit设置为False
                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
                jit = False
            state_dict = torch.load(opened_file, map_location="cpu")  # 加载状态字典

    if not jit:  # 如果不使用JIT模型
        # 根据状态字典构建模型并将其移动到指定设备
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":  # 如果设备是CPU
            model.float()  # 将模型参数转换为float类型
        # 返回模型和图像预处理函数
        return model, _transform(model.visual.input_resolution)

    # 对JIT模型的设备名称进行修补
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []  # 获取模块的图
        except RuntimeError:
            graphs = []  # 如果获取失败,则将图列表置为空

        if hasattr(module, "forward1"):  # 如果模块有forward1方法
            graphs.append(module.forward1.graph)  # 将forward1方法的图添加到图列表中

        for graph in graphs:  # 遍历图列表
            for node in graph.findAllNodes("prim::Constant"):  # 遍历图中的所有常量节点
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):  # 如果节点的值以cuda开头
                    node.copyAttributes(device_node)  # 复制设备节点的属性

    model.apply(patch_device)  # 对模型应用设备修补函数
    patch_device(model.encode_image)  # 对模型的图像编码函数应用设备修补函数
    patch_device(model.encode_text)  # 对模型的文本编码函数应用设备修补函数

    # 在CPU上将模型的数据类型修补为float32
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []  # 获取模块的图
            except RuntimeError:
                graphs = []  # 如果获取失败,则将图列表置为空

            if hasattr(module, "forward1"):  # 如果模块有forward1方法
                graphs.append(module.forward1.graph)  # 将forward1方法的图添加到图列表中

            for graph in graphs:  # 遍历图列表
                for node in graph.findAllNodes("aten::to"):  # 遍历图中的所有aten::to节点
                    inputs = list(node.inputs())  # 获取节点的输入列表
                    for i in [1, 2]:  # 遍历输入列表的第1和第2个元素
                        if inputs[i].node()["value"] == 5:  # 如果输入节点的值为5
                            inputs[i].node().copyAttributes(float_node)  # 复制float节点的属性

        model.apply(patch_float)  # 对模型应用数据类型修补函数
        patch_float(model.encode_image)  # 对模型的图像编码函数应用数据类型修补函数
        patch_float(model.encode_text)  # 对模型的文本编码函数应用数据类型修补函数

        model.float()  # 将模型参数转换为float类型

    # 返回模型和图像预处理函数
    return model, _transform(model.input_resolution.item())

# 定义文本分词函数,用于将输入的文本转换为分词后的张量
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
    """
    if isinstance(texts, str):  # 如果输入的文本是字符串
        texts = [texts]  # 将其转换为字符串列表

    sot_token = _tokenizer.encoder["<|startoftext|>"]  # 获取起始标记的编码
    eot_token = _tokenizer.encoder["<|endoftext|>"]  # 获取结束标记的编码
    # 对每个文本进行分词,并添加起始和结束标记
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):  # 如果PyTorch版本低于1.8.0
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)  # 创建一个全零的长整型张量
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)  # 创建一个全零的整型张量

    for i, tokens in enumerate(all_tokens):  # 遍历所有分词后的文本
        if len(tokens) > context_length:  # 如果分词后的长度超过上下文长度
            if truncate:  # 如果允许截断
                tokens = tokens[:context_length]  # 截断分词后的文本
                tokens[-1] = eot_token  # 将最后一个标记设置为结束标记
            else:
                # 如果不允许截断,则抛出运行时错误
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)  # 将分词后的文本填充到结果张量中

    return result  # 返回结果张量

dataloader_images.py:定义了lowlight_loader类,用于加载低光照图像数据集。支持数据增强操作,如翻转、旋转和缩放等。

import os
import sys

import torch
import torch.utils.data as data
import torchvision.transforms as transforms

import numpy as np
from PIL import Image
import glob
import random
import cv2
import clip

# 设置随机种子,确保实验可重复性
random.seed(1143)

# 定义一个函数,用于将变换矩阵的中心偏移到图像中心
def transform_matrix_offset_center(matrix, x, y):
    """Return transform matrix offset center.

    Parameters
    ----------
    matrix : numpy array
        Transform matrix
    x, y : int
        Size of image.

    Examples
    --------
    - See ``rotation``, ``shear``, ``zoom``.
    """
    # 计算图像的中心坐标
    o_x = float(x) / 2 + 0.5
    o_y = float(y) / 2 + 0.5
    # 定义偏移矩阵,将原点移动到图像中心
    offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
    # 定义重置矩阵,将原点移回原来的位置
    reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
    # 计算最终的变换矩阵
    transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix)
    return transform_matrix 

# 定义一个函数,用于旋转图像
def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate image.
    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.
    """
    # 获取图像的高度和宽度
    (h, w) = img.shape[:2]

    # 如果没有指定旋转中心,则将图像中心作为旋转中心
    if center is None:
        center = (w // 2, h // 2)

    # 计算旋转矩阵
    matrix = cv2.getRotationMatrix2D(center, angle, scale)
    # 应用旋转矩阵对图像进行旋转
    rotated_img = cv2.warpAffine(img, matrix, (w, h),flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT, borderValue=(0,0,0),)
    return rotated_img

# 定义一个函数,用于对图像进行缩放
def zoom(x, zx, zy, row_axis=0, col_axis=1):
    # 定义缩放矩阵
    zoom_matrix = np.array([[zx, 0, 0],
                            [0, zy, 0],
                            [0, 0, 1]])
    # 获取图像的高度和宽度
    h, w = x.shape[row_axis], x.shape[col_axis]

    # 将缩放矩阵的中心偏移到图像中心
    matrix = transform_matrix_offset_center(zoom_matrix, h, w) 
    # 应用缩放矩阵对图像进行缩放
    x = cv2.warpAffine(x, matrix[:2, :], (w, h),flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT, borderValue=(0,0,0),)
    return x

# 定义一个函数,用于对图像进行数据增强
def augmentation(img1,img2):
    # 随机决定是否进行水平翻转
    hflip=random.random() < 0.5
    # 随机决定是否进行垂直翻转
    vflip=random.random() < 0.5
    # 随机决定是否进行90度旋转
    rot90=random.random() < 0.5
    # 随机决定是否进行任意角度旋转
    rot=random.random() <0.3
    # 随机决定是否进行缩放
    zo=random.random()<0.3
    # 随机生成一个旋转角度
    angle=random.random()*180-90
    # 如果需要水平翻转,则对图像进行水平翻转
    if hflip:
        img1=cv2.flip(img1,1)
        img2=cv2.flip(img2,1)
    # 如果需要垂直翻转,则对图像进行垂直翻转
    if vflip:
        img1=cv2.flip(img1,0)
        img2=cv2.flip(img2,0)
    # 如果需要90度旋转,则对图像进行90度旋转
    if rot90:
        img1 = img1.transpose(1, 0, 2)
        img2 = img2.transpose(1,0,2)
    # 如果需要缩放,则对图像进行缩放
    if zo:
        # 定义缩放范围
        zoom_range=(0.7, 1.3)
        # 随机生成缩放因子
        zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2)
        img1=zoom(img1, zx, zy)
        img2=zoom(img2,zx,zy)
    # 如果需要任意角度旋转,则对图像进行旋转
    if rot:
        img1=img_rotate(img1,angle)
        img2=img_rotate(img2,angle)
    return img1,img2

# 定义一个函数,用于对图像进行预处理和数据增强
def preprocess_aug(img1,img2):
    # 将图像转换为numpy数组,并转换为uint8类型
    img1 = np.uint8((np.asarray(img1)))
    img2 = np.uint8((np.asarray(img2)))
    # 将图像从RGB颜色空间转换为BGR颜色空间
    img1 = cv2.cvtColor(np.array(img1), cv2.COLOR_RGB2BGR)
    img2 = cv2.cvtColor(np.array(img2), cv2.COLOR_RGB2BGR)
    # 对图像进行数据增强
    img1,img2=augmentation(img1,img2)
    # 将图像从BGR颜色空间转换回RGB颜色空间,并转换为PIL图像
    img1 = Image.fromarray(cv2.cvtColor(img1, cv2.COLOR_BGR2RGB))
    img2 = Image.fromarray(cv2.cvtColor(img2, cv2.COLOR_BGR2RGB))
    return img1,img2

# 定义设备,如果有可用的GPU则使用GPU,否则使用CPU
device = "cpu"#"cuda" if torch.cuda.is_available() else "cpu"
# 加载CLIP模型
model, preprocess = clip.load("ViT-B/32", device=device, download_root="./clip_model/")#ViT-B/32
# 冻结CLIP模型的所有参数,不进行训练
for para in model.parameters():
    para.requires_grad = False

# 定义一个函数,用于生成训练图像列表
def populate_train_list(lowlight_images_path,overlight_images_path=None):
    # 如果提供了过亮图像的路径
    if overlight_images_path!=None:
        # 获取低光照图像列表
        image_list_lowlight = glob.glob(lowlight_images_path + "*")
        # 获取过亮图像列表
        image_list_overlight = glob.glob(overlight_images_path + "*")
        # 将低光照图像列表和过亮图像列表合并
        image_list_lowlight += image_list_overlight
    else:
        # 只获取低光照图像列表
        image_list_lowlight = glob.glob(lowlight_images_path + "*")

    # 对图像列表进行排序
    train_list = sorted(image_list_lowlight)
    # 随机打乱图像列表的顺序
    random.shuffle(train_list)

    return train_list

# 定义一个自定义数据集类,用于加载低光照图像
class lowlight_loader(data.Dataset):

    def __init__(self, lowlight_images_path,overlight_images_path=None):
        # 生成训练图像列表
        self.train_list = populate_train_list(lowlight_images_path,overlight_images_path) 
        # 定义图像的大小
        self.size = 512

        # 将训练图像列表赋值给数据列表
        self.data_list = self.train_list
        # 打印训练图像的总数
        print("Total training examples (Backlit):", len(self.train_list))


    def __getitem__(self, index):
        # 获取当前索引对应的图像路径
        data_lowlight_path = self.data_list[index]
        # 打开图像
        data_lowlight = Image.open(data_lowlight_path)

        # 如果图像路径中不包含"result"
        if("result" not in data_lowlight_path):
            # 将图像调整为指定大小
            data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
        # 对图像进行预处理和数据增强
        data_lowlight,_=preprocess_aug(data_lowlight,data_lowlight)
        
        # 将图像转换为numpy数组,并进行归一化处理
        data_lowlight = (np.asarray(data_lowlight)/255.0) 
        # 将numpy数组转换为torch张量,并调整维度
        data_lowlight_output = torch.from_numpy(data_lowlight).float().permute(2,0,1)
        
        return data_lowlight_output,data_lowlight_path

    def __len__(self):
        # 返回数据列表的长度,即训练图像的总数
        return len(self.data_list)

dataloader_prompt_add.py:同样定义了lowlight_loader类,用于加载提示相关的数据集。专门为提示学习(Prompt Learning)设计,在数据加载过程中融入文本提示信息。通过结合图像和文本提示,引导模型学习更符合语义的图像增强策略。

import os
import sys

import torch
import torch.utils.data as data
import torchvision.transforms as transforms

import numpy as np
from PIL import Image
import glob
import random
import cv2
import clip

# 定义一个函数,用于将变换矩阵的中心偏移到图像中心
def transform_matrix_offset_center(matrix, x, y):
    """Return transform matrix offset center.

    Parameters
    ----------
    matrix : numpy array
        Transform matrix
    x, y : int
        Size of image.

    Examples
    --------
    - See ``rotation``, ``shear``, ``zoom``.
    """
    # 计算图像中心的x坐标
    o_x = float(x) / 2 + 0.5
    # 计算图像中心的y坐标
    o_y = float(y) / 2 + 0.5
    # 定义一个偏移矩阵,用于将原点移到图像中心
    offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
    # 定义一个重置矩阵,用于将原点移回原来的位置
    reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
    # 计算最终的变换矩阵,通过将偏移矩阵、输入矩阵和重置矩阵相乘
    transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix)
    return transform_matrix 

# 定义一个函数,用于旋转图像
def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate image.
    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.
    """
    # 获取图像的高度和宽度
    (h, w) = img.shape[:2]

    # 如果没有指定旋转中心,则将图像中心作为旋转中心
    if center is None:
        center = (w // 2, h // 2)

    # 计算旋转矩阵
    matrix = cv2.getRotationMatrix2D(center, angle, scale)
    # 应用旋转矩阵对图像进行旋转,使用线性插值和反射填充边界
    rotated_img = cv2.warpAffine(img, matrix, (w, h),flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT, borderValue=(0,0,0),)
    return rotated_img

# 定义一个函数,用于对图像进行缩放操作
def zoom(x, zx, zy, row_axis=0, col_axis=1):
    # 定义缩放矩阵
    zoom_matrix = np.array([[zx, 0, 0],
                            [0, zy, 0],
                            [0, 0, 1]])
    # 获取图像的高度和宽度
    h, w = x.shape[row_axis], x.shape[col_axis]

    # 将缩放矩阵的中心偏移到图像中心
    matrix = transform_matrix_offset_center(zoom_matrix, h, w) 
    # 应用缩放矩阵对图像进行缩放,使用线性插值和反射填充边界
    x = cv2.warpAffine(x, matrix[:2, :], (w, h),flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT, borderValue=(0,0,0),)
    return x

# 定义一个函数,用于对图像进行数据增强操作
def augmentation(img1,img2):
    # 随机决定是否进行水平翻转
    hflip=random.random() < 0.5
    # 随机决定是否进行垂直翻转
    vflip=random.random() < 0.5
    # 随机决定是否进行90度旋转
    rot90=random.random() < 0.5
    # 随机决定是否进行任意角度旋转
    rot=random.random() <0.3
    # 随机决定是否进行缩放
    zo=random.random()<0.3
    # 随机生成一个旋转角度
    angle=random.random()*180-90
    # 如果需要水平翻转,则对图像进行水平翻转
    if hflip:
        img1=cv2.flip(img1,1)
        img2=cv2.flip(img2,1)
    # 如果需要垂直翻转,则对图像进行垂直翻转
    if vflip:
        img1=cv2.flip(img1,0)
        img2=cv2.flip(img2,0)
    # 如果需要90度旋转,则对图像进行90度旋转
    if rot90:
        img1 = img1.transpose(1, 0, 2)
        img2 = img2.transpose(1,0,2)
    # 如果需要缩放,则对图像进行缩放
    if zo:
        # 定义缩放范围
        zoom_range=(0.5, 1.5)
        # 随机生成缩放因子
        zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2)
        img1=zoom(img1, zx, zy)
        img2=zoom(img2,zx,zy)
    # 如果需要任意角度旋转,则对图像进行旋转
    if rot:
        img1=img_rotate(img1,angle)
        img2=img_rotate(img2,angle)
    return img1,img2

# 定义一个函数,用于对图像进行预处理和数据增强
def preprocess_aug(img1,img2):
    # 将图像转换为numpy数组,并转换为uint8类型
    img1 = np.uint8((np.asarray(img1)))
    img2 = np.uint8((np.asarray(img2)))
    # 将图像从RGB颜色空间转换为BGR颜色空间
    img1 = cv2.cvtColor(np.array(img1), cv2.COLOR_RGB2BGR)
    img2 = cv2.cvtColor(np.array(img2), cv2.COLOR_RGB2BGR)
    # 对图像进行数据增强
    img1,img2=augmentation(img1,img2)
    # 将图像从BGR颜色空间转换回RGB颜色空间,并转换为PIL图像
    img1 = Image.fromarray(cv2.cvtColor(img1, cv2.COLOR_BGR2RGB))
    img2 = Image.fromarray(cv2.cvtColor(img2, cv2.COLOR_BGR2RGB))
    return img1,img2

# 定义设备,如果有可用的GPU则使用GPU,否则使用CPU
device = "cpu"#"cuda" if torch.cuda.is_available() else "cpu"
# 加载CLIP模型
model, preprocess = clip.load("ViT-B/32", device=device, download_root="./clip_model/")#ViT-B/32
# 冻结CLIP模型的所有参数,不进行训练
for para in model.parameters():
    para.requires_grad = False

# 定义一个函数,用于生成训练图像列表
def populate_train_list(lowlight_images_path,normallight_images_path,overlight_images_path=None):
    # 如果提供了过亮图像的路径
    if overlight_images_path!=None:
        # 获取过亮图像列表
        image_list_overlight = glob.glob(overlight_images_path + "*")
        # 获取低光照图像列表
        image_list_lowlight = glob.glob(lowlight_images_path + "*")
        # 获取正常光照图像列表
        image_list_normallight = glob.glob(normallight_images_path + "*")
        # 将低光照、正常光照和过亮图像列表合并
        train_list = image_list_lowlight+image_list_normallight+image_list_overlight
    else:
        # 获取低光照图像列表
        image_list_lowlight = glob.glob(lowlight_images_path + "*")
        # 获取正常光照图像列表
        image_list_normallight = glob.glob(normallight_images_path + "*")
        # 复制正常光照图像列表
        image_ref_list=image_list_normallight.copy()
        # 复制低光照图像列表
        image_input_list=image_list_lowlight.copy()
        # 如果正常光照或低光照图像列表为空,则抛出异常
        if len(image_list_normallight)==0 or len(image_list_lowlight)==0:
            raise Exception("one of the image lists is empty!", len(image_list_normallight),len(image_list_lowlight))
        # 如果正常光照图像数量少于低光照图像数量
        if len(image_list_normallight)<len(image_list_lowlight):
            # 不断添加正常光照图像,直到数量与低光照图像相同
            while(len(image_ref_list)<len(image_list_lowlight)):
                for i in image_list_normallight:
                    image_ref_list.append(i)
                    if(len(image_ref_list)>=len(image_list_lowlight)):
                        break
        else:
            # 如果低光照图像数量少于正常光照图像数量
            while(len(image_input_list)<len(image_list_normallight)):
                # 不断添加低光照图像,直到数量与正常光照图像相同
                for i in image_list_lowlight:
                    image_input_list.append(i)
                    if(len(image_input_list)>=len(image_list_normallight)):
                        break
        
        # 将低光照和正常光照图像列表合并
        train_list = image_input_list+image_ref_list
    # 随机打乱训练图像列表的顺序
    random.shuffle(train_list)

    return train_list

# 定义一个自定义数据集类,用于加载低光照图像
class lowlight_loader(data.Dataset):

    def __init__(self, lowlight_images_path,normallight_images_path,overlight_images_path=None):
        # 如果提供了过亮图像的路径
        if overlight_images_path!=None:
            # 生成包含低光照、正常光照和过亮图像的训练列表
            self.train_list = populate_train_list(lowlight_images_path,normallight_images_path,overlight_images_path)
        else:
            # 生成包含低光照和正常光照图像的训练列表
            self.train_list = populate_train_list(lowlight_images_path,normallight_images_path)
        # 定义图像的大小
        self.size = 256

        # 将训练图像列表赋值给数据列表
        self.data_list = self.train_list
        # 打印训练图像的总数
        print("Total training examples (max(Backlit,Well-lit)*2):", len(self.train_list))

    def __getitem__(self, index):
        # 获取当前索引对应的图像路径
        data_lowlight_path = self.data_list[index]
        # 打开图像
        data_lowlight = Image.open(data_lowlight_path)
        # 将图像调整为指定大小
        data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
        # 对图像进行预处理和数据增强
        data_lowlight,_=preprocess_aug(data_lowlight,data_lowlight)
        # 将图像转换为numpy数组,并进行归一化处理
        data_lowlight = (np.asarray(data_lowlight)/255.0) 
        # 将numpy数组转换为torch张量
        data_lowlight = torch.from_numpy(data_lowlight).float()
        # 调整张量的维度,并将其移动到指定设备上
        image_lowlight=data_lowlight.permute(2,0,1).to(device)
        # 定义CLIP模型的归一化参数
        clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        # 定义图像缩放操作
        img_resize = transforms.Resize((224,224))
        # 对图像进行缩放
        image2=img_resize(image_lowlight)
        # 对图像进行归一化处理,并调整维度
        image=clip_normalizer((image2.reshape(1,3,224,224)))
        # 使用CLIP模型对图像进行编码,得到图像特征
        image_features = model.encode_image(image)
        # 对图像特征进行归一化处理
        image_features /= image_features.norm(dim=-1, keepdim=True)
        # 如果图像路径中包含特定关键词,则标记为正常光照图像
        if ("normal" in data_lowlight_path)or("output"in data_lowlight_path)or("/data/HQ/"in data_lowlight_path)or("DIV2K"in data_lowlight_path)or("high"in data_lowlight_path):
            label=torch.from_numpy(np.array(1))
        else:
            # 否则标记为低光照图像
            label=torch.from_numpy(np.array(0))
    
        return image_features,label

    def __len__(self):
        # 返回数据列表的长度,即训练图像的总数
        return len(self.data_list)

dataloader_prompt_margin.py 第二阶段数据加载器,支持多阶段训练策略,如逐步引入半监督样本。以下是为代码的每一行添加的注释:

# 导入必要的库
from ctypes import sizeof
import os
import sys

import torch
import torch.utils.data as data
import torchvision.transforms as transforms

import numpy as np
from PIL import Image
import glob
import random
import cv2
import clip

# 定义函数,用于将变换矩阵的中心偏移到图像中心
def transform_matrix_offset_center(matrix, x, y):
    """Return transform matrix offset center.

    Parameters
    ----------
    matrix : numpy array
        Transform matrix
    x, y : int
        Size of image.

    Examples
    --------
    - See ``rotation``, ``shear``, ``zoom``.
    """
    # 计算图像的中心点
    o_x = float(x) / 2 + 0.5
    o_y = float(y) / 2 + 0.5
    # 定义偏移矩阵
    offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
    # 定义重置矩阵
    reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
    # 计算最终的变换矩阵
    transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix)
    return transform_matrix 

# 定义函数,用于旋转图像
def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate image.
    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.
    """
    # 获取图像的高度和宽度
    (h, w) = img.shape[:2]

    # 如果未指定旋转中心,则使用图像的中心
    if center is None:
        center = (w // 2, h // 2)

    # 计算旋转矩阵
    matrix = cv2.getRotationMatrix2D(center, angle, scale)
    # 应用旋转变换
    rotated_img = cv2.warpAffine(img, matrix, (w, h),flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT, borderValue=(0,0,0),)
    return rotated_img

# 定义函数,用于缩放图像
def zoom(x, zx, zy, row_axis=0, col_axis=1):
    # 定义缩放矩阵
    zoom_matrix = np.array([[zx, 0, 0],
                            [0, zy, 0],
                            [0, 0, 1]])
    # 获取图像的高度和宽度
    h, w = x.shape[row_axis], x.shape[col_axis]

    # 将缩放矩阵的中心偏移到图像中心
    matrix = transform_matrix_offset_center(zoom_matrix, h, w)
    # 应用缩放变换
    x = cv2.warpAffine(x, matrix[:2, :], (w, h),flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT, borderValue=(0,0,0),)
    return x

# 定义函数,用于对图像进行增强操作
def augmentation(img,hflip,vflip,rot90,rot,zo,angle,zx,zy):
    # 水平翻转图像
    if hflip:
        img=cv2.flip(img,1)
    # 垂直翻转图像
    if vflip:
        img=cv2.flip(img,0)
    # 旋转90度
    if rot90:
        img = img.transpose(1, 0, 2)
    # 缩放图像
    if zo:
        img=zoom(img, zx, zy)
    # 旋转图像
    if rot:
        img=img_rotate(img,angle)
    return img

# 定义函数,用于对图像列表进行预处理和增强操作
def preprocess_aug(img_list):
    # 随机决定是否进行水平翻转
    hflip=random.random() < 0.5
    # 随机决定是否进行垂直翻转
    vflip=random.random() < 0.5
    # 随机决定是否旋转90度
    rot90=random.random() < 0.5
    # 随机决定是否旋转
    rot=random.random() <0.3
    # 随机决定是否缩放
    zo=random.random()<0.3
    # 随机生成旋转角度
    angle=random.random()*180-90
    # 定义缩放范围
    zoom_range=(0.5, 1.5)
    # 随机生成缩放因子
    zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2)
    # 初始化增强后的图像列表
    aug_img_list=[]
    # 遍历图像列表
    for img in img_list:
        # 将图像转换为numpy数组并转换为uint8类型
        img = np.uint8((np.asarray(img)))
        # 将图像从RGB颜色空间转换为BGR颜色空间
        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        # 对图像进行增强操作
        img=augmentation(img,hflip,vflip,rot90,rot,zo,angle,zx,zy)
        # 将图像从BGR颜色空间转换为RGB颜色空间,并转换为PIL图像
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        # 将增强后的图像添加到列表中
        aug_img_list.append(img)
    return aug_img_list

# 定义设备,使用CPU进行计算
device = "cpu"
# 加载CLIP模型
model, preprocess = clip.load("ViT-B/32", device=device, download_root="./clip_model/")#ViT-B/32
# 冻结CLIP模型的参数,不进行训练
for para in model.parameters():
    para.requires_grad = False

# 定义函数,用于生成训练数据列表
def populate_train_list(lowlight_images_path,normallight_images_path=None,overlight_images_path=None):
    # 获取低光照图像列表
    image_list_lowlight = glob.glob(lowlight_images_path + "*")
    # 获取正常光照图像列表
    image_list_normallight = glob.glob(normallight_images_path+"*")
    
    # 复制正常光照图像列表
    image_ref_list=image_list_normallight.copy()
    # 复制低光照图像列表
    image_input_list=image_list_lowlight.copy()
    # 检查图像列表是否为空
    if len(image_list_normallight)==0 or len(image_list_lowlight)==0:
        raise Exception("one of the image lists is empty!", len(image_list_normallight),len(image_list_lowlight))
    # 如果正常光照图像列表长度小于低光照图像列表长度
    if len(image_list_normallight)<len(image_list_lowlight):
        # 不断添加正常光照图像,直到两个列表长度相等
        while(len(image_ref_list)<len(image_list_lowlight)):
            for i in image_list_normallight:
                image_ref_list.append(i)
                if(len(image_ref_list)>=len(image_list_lowlight)):
                    break
    else:
        # 如果低光照图像列表长度小于正常光照图像列表长度
        while(len(image_input_list)<len(image_list_normallight)):
            for i in image_list_lowlight:
                image_input_list.append(i)
                if(len(image_input_list)>=len(image_list_normallight)):
                    break
    # 定义训练数据列表1
    train_list1=image_input_list
    # 定义训练数据列表2
    train_list2=image_ref_list
    # 随机打乱训练数据列表1
    random.shuffle(train_list1)
    # 随机打乱训练数据列表2
    random.shuffle(train_list2)

    return train_list1,train_list2

# 定义函数,用于预处理图像并提取特征
def preprocess_feature(img):
    # 将图像转换为numpy数组并归一化
    img = (np.asarray(img)/255.0) 
    # 将numpy数组转换为torch张量
    img = torch.from_numpy(img).float()
    # 调整张量的维度
    img=img.permute(2,0,1).to(device)
    # 定义归一化变换
    clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    # 定义缩放变换
    img_resize = transforms.Resize((224,224))
    # 应用缩放变换
    img=img_resize(img)
    # 调整张量的维度并应用归一化变换
    img=clip_normalizer(img.reshape(1,3,224,224))
    # 使用CLIP模型提取图像特征
    image_features = model.encode_image(img)
    # 对图像特征进行归一化
    image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features

# 定义自定义数据集类
class lowlight_loader(data.Dataset):

    def __init__(self, lowlight_images_path,normallight_images_path,semi1_path=None,semi2_path=None):
        # 生成训练数据列表
        self.train_list1,self.train_list2 = populate_train_list(lowlight_images_path,normallight_images_path)
        # 定义图像的大小
        self.size = 256
        # 定义负样本图像路径
        self.neg_path=lowlight_images_path
        # 定义半监督样本1的路径
        self.semi1_path=semi1_path
        # 定义半监督样本2的路径
        self.semi2_path=semi2_path
        # 定义数据列表
        self.data_list = self.train_list1
        # 打印正常光照图像的数量
        print("Total training examples (Well-lit):", len(self.train_list2))
        

    def __getitem__(self, index):
        # 获取低光照图像的路径
        data_lowlight_path = self.data_list[index]
        # 获取参考图像的路径
        ref_path = self.train_list2[index]
        
        # 打开低光照图像
        data_lowlight = Image.open(data_lowlight_path)
        # 打开参考图像
        ref = Image.open(ref_path)

        # 调整低光照图像的大小
        data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
        # 调整参考图像的大小
        ref = ref.resize((self.size,self.size), Image.ANTIALIAS)
        # 如果没有半监督样本1的路径
        if self.semi1_path==None:
            # 对低光照图像和参考图像进行预处理和增强操作
            img_list=preprocess_aug([data_lowlight,ref])
        # 如果没有半监督样本2的路径
        elif self.semi2_path==None:
            # 打开半监督样本1的图像
            semi1 = Image.open(data_lowlight_path.replace(self.neg_path,self.semi1_path).replace('.JPG','.png'))
            # 对低光照图像、半监督样本1和参考图像进行预处理和增强操作
            img_list=preprocess_aug([data_lowlight,semi1,ref])
        else:
            # 打开半监督样本1的图像
            semi1 = Image.open(data_lowlight_path.replace(self.neg_path,self.semi1_path).replace('.JPG','.png'))
            # 打开半监督样本2的图像
            semi2 = Image.open(data_lowlight_path.replace(self.neg_path,self.semi2_path).replace('.JPG','.png'))
            # 对低光照图像、半监督样本1、半监督样本2和参考图像进行预处理和增强操作
            img_list=preprocess_aug([data_lowlight,semi1,semi2,ref])
            
        # 初始化图像特征列表
        img_feature_list=[]
        # 遍历图像列表
        for img in img_list:
            # 预处理图像并提取特征
            img_feature=preprocess_feature(img)
            # 将图像特征添加到列表中
            img_feature_list.append(img_feature)
        
        # 返回图像特征列表和标签
        return img_feature_list,1

    def __len__(self):
        # 返回数据列表的长度
        return len(self.data_list)

simple_tokenizer.py 文件是实现一个简单的文本分词器(tokenizer),它使用字节对编码(Byte Pair Encoding, BPE)技术将文本拆分成词元(tokens),并将词元转换为对应的整数 ID,以便后续的自然语言处理模型能够处理。同时,它也提供了解码功能,将整数 ID 序列转换回原始文本。以下是为代码的每一行添加的注释:

import gzip  # 用于处理 gzip 压缩文件
import html  # 用于处理 HTML 实体编码
import os  # 用于操作系统相关的功能,如文件路径操作
from functools import lru_cache  # 用于实现函数结果的缓存,避免重复计算

import ftfy  # 用于修复文本中的各种编码和格式问题
import regex as re  # 用于正则表达式操作


@lru_cache()  # 使用 lru_cache 缓存函数结果,避免重复读取文件
def default_bpe():
    # 返回默认的 BPE 词汇表文件路径
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()  # 使用 lru_cache 缓存函数结果,避免重复计算
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    # 定义一组常见字符的 ASCII 码范围
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]  # 复制字符范围列表
    n = 0
    # 遍历所有 256 个字节
    for b in range(2**8):
        if b not in bs:
            bs.append(b)  # 将未包含的字节添加到列表中
            cs.append(2**8+n)  # 为未包含的字节分配一个唯一的 Unicode 码点
            n += 1
    cs = [chr(n) for n in cs]  # 将 Unicode 码点转换为字符
    return dict(zip(bs, cs))  # 返回字节到 Unicode 字符的映射字典


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()  # 初始化一个空集合用于存储字符对
    prev_char = word[0]  # 获取单词的第一个字符
    for char in word[1:]:  # 遍历单词中的其他字符
        pairs.add((prev_char, char))  # 将相邻的字符对添加到集合中
        prev_char = char  # 更新前一个字符
    return pairs  # 返回字符对集合


def basic_clean(text):
    text = ftfy.fix_text(text)  # 修复文本中的编码和格式问题
    text = html.unescape(html.unescape(text))  # 解码 HTML 实体编码
    return text.strip()  # 去除文本两端的空白字符


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)  # 将连续的空白字符替换为单个空格
    text = text.strip()  # 去除文本两端的空白字符
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()  # 初始化字节到 Unicode 字符的编码器
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}  # 初始化 Unicode 字符到字节的解码器
        # 打开并读取 BPE 合并规则文件
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        # 截取合并规则的有效部分
        merges = merges[1:49152-256-2+1]
        # 将合并规则转换为元组形式
        merges = [tuple(merge.split()) for merge in merges]
        # 初始化词汇表,包含字节到 Unicode 字符映射的所有值
        vocab = list(bytes_to_unicode().values())
        # 为词汇表中的每个字符添加词尾标记 '</w>'
        vocab = vocab + [v+'</w>' for v in vocab]
        # 将合并规则生成的新字符添加到词汇表中
        for merge in merges:
            vocab.append(''.join(merge))
        # 添加特殊标记 '<|startoftext|>' 和 '<|endoftext|>' 到词汇表中
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        # 初始化编码器,将词汇表中的每个字符映射到一个唯一的整数 ID
        self.encoder = dict(zip(vocab, range(len(vocab))))
        # 初始化解码器,将整数 ID 映射回词汇表中的字符
        self.decoder = {v: k for k, v in self.encoder.items()}
        # 初始化合并规则的排序字典
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # 初始化缓存,用于存储已经处理过的特殊标记
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        # 定义正则表达式模式,用于匹配文本中的各种元素
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:  # 如果 token 已经在缓存中
            return self.cache[token]  # 直接返回缓存中的结果
        # 将 token 转换为元组形式,并为最后一个字符添加词尾标记 '</w>'
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)  # 获取 token 中的字符对集合

        if not pairs:  # 如果没有字符对
            return token+'</w>'  # 直接返回添加词尾标记的 token

        while True:
            # 找到合并优先级最高的字符对
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:  # 如果该字符对不在合并规则中
                break  # 停止合并过程
            first, second = bigram  # 解包字符对
            new_word = []  # 初始化新的单词列表
            i = 0
            while i < len(word):
                try:
                    # 找到字符对中第一个字符的位置
                    j = word.index(first, i)
                    new_word.extend(word[i:j])  # 将之前的字符添加到新单词列表中
                    i = j  # 更新索引
                except:
                    new_word.extend(word[i:])  # 将剩余的字符添加到新单词列表中
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)  # 合并字符对
                    i += 2  # 更新索引
                else:
                    new_word.append(word[i])  # 添加单个字符
                    i += 1
            new_word = tuple(new_word)  # 将新单词列表转换为元组
            word = new_word  # 更新单词
            if len(word) == 1:  # 如果单词只有一个元素
                break  # 停止合并过程
            else:
                pairs = get_pairs(word)  # 重新获取字符对集合
        word = ' '.join(word)  # 将单词转换为字符串
        self.cache[token] = word  # 将结果存入缓存
        return word  # 返回合并后的单词

    def encode(self, text):
        bpe_tokens = []  # 初始化 BPE 词元列表
        # 对文本进行基本清理和空白字符清理,并转换为小写
        text = whitespace_clean(basic_clean(text)).lower()
        # 使用正则表达式匹配文本中的各个元素
        for token in re.findall(self.pat, text):
            # 将 token 中的每个字节转换为对应的 Unicode 字符
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            # 对 token 进行 BPE 编码,并将结果转换为整数 ID 列表
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens  # 返回 BPE 词元的整数 ID 列表

    def decode(self, tokens):
        # 将整数 ID 列表转换为字符列表
        text = ''.join([self.decoder[token] for token in tokens])
        # 将 Unicode 字符转换为字节,并解码为 UTF-8 字符串
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text  # 返回解码后的文本

clip_score.py 文件的主要作用是定义与 CLIP(Contrastive Language-Image Pretraining)模型相关的得分计算函数和损失函数,用于图像增强、图像质量评估等深度学习任务。以下是对代码逐行添加注释后的详细解释:

from turtle import forward
import torchvision.transforms as transforms
import torch
import clip
import torch.nn as nn
from torch.nn import functional as F
from CLIP.clip import load

# 检查是否有可用的CUDA GPU,如果有则使用GPU,否则使用CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# 加载CLIP模型,这里使用的是ViT-B/32版本,并将其下载到指定目录
model, preprocess = clip.load("ViT-B/32", device=torch.device("cpu"), download_root="./clip_model/")#"ViT-B/32"
# 将模型移动到指定设备(GPU或CPU)
model.to(device)
# 冻结CLIP模型的所有参数,不进行训练
for para in model.parameters():
    para.requires_grad = False

# 定义一个函数,用于计算图像张量与给定文本的CLIP得分
def get_clip_score(tensor,words):
    score = 0
    # 遍历输入张量中的每个图像
    for i in range(tensor.shape[0]):
        # 图像预处理:归一化和调整大小
        clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        img_resize = transforms.Resize((224,224))
        image2 = img_resize(tensor[i])
        image = clip_normalizer(image2).unsqueeze(0)
        # 将文本进行分词并移动到指定设备
        text = clip.tokenize(words).to(device)
        # 使用CLIP模型计算图像和文本的logits
        logits_per_image, logits_per_text = model(image, text)
        # 对图像的logits进行softmax操作,得到概率分布
        probs = logits_per_image.softmax(dim=-1)
        # 取第一个文本对应的概率作为得分
        prob = probs[0][0]
        score = score + prob
    return score

# 定义一个自定义的CLIP损失类
class L_clip(nn.Module):
    def __init__(self):
        super(L_clip,self).__init__()
        # 冻结该损失类的所有参数,不进行训练
        for param in self.parameters(): 
            param.requires_grad = False
  
    def forward(self, x, light):
        # 计算图像与 "dark" 和 "normal light" 文本的CLIP得分
        k1 = get_clip_score(x,["dark","normal light"])
        if light:
            # 如果light为True,计算图像与 "noisy photo" 和 "clear photo" 文本的CLIP得分
            k2 = get_clip_score(x,["noisy photo","clear photo"])
            return (k1 + k2) / 2
        return k1

# 定义一个自定义的提示类,用于学习文本特征
class Prompts(nn.Module):
    def __init__(self, initials=None):
        super(Prompts,self).__init__()
        if initials != None:
            # 如果提供了初始文本,将其分词并移动到GPU
            text = clip.tokenize(initials).cuda()
            with torch.no_grad():
                # 使用CLIP模型编码文本特征
                self.text_features = model.encode_text(text).cuda()
        else:
            # 如果没有提供初始文本,随机初始化文本特征
            self.text_features = torch.nn.init.xavier_normal_(nn.Parameter(torch.cuda.FloatTensor(2,512))).cuda()

    def forward(self, tensor):
        for i in range(tensor.shape[0]):
            image_features = tensor[i]
            # 对文本特征进行归一化
            nor = torch.norm(self.text_features, dim=-1, keepdim=True)
            # 计算图像特征与文本特征的相似度,并进行softmax操作
            similarity = (model.logit_scale.exp() * image_features @ (self.text_features / nor).T).softmax(dim=-1)
            if i == 0:
                probs = similarity
            else:
                probs = torch.cat([probs, similarity], dim=0)
        return probs

# 初始化一个Prompts类的实例,并将其移动到GPU
learn_prompt = Prompts().cuda()
# 定义图像归一化和调整大小的变换
clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
img_resize = transforms.Resize((224,224))

# 定义一个函数,用于从图像特征计算CLIP得分
def get_clip_score_from_feature(tensor, text_features):
    score = 0
    for i in range(tensor.shape[0]):
        # 对图像进行预处理
        image2 = img_resize(tensor[i])
        image = clip_normalizer(image2.reshape(1,3,224,224))
        # 使用CLIP模型编码图像特征
        image_features = model.encode_image(image)
        # 对图像特征和文本特征进行归一化
        image_nor = image_features.norm(dim=-1, keepdim=True)
        nor = text_features.norm(dim=-1, keepdim=True)
        # 计算图像特征与文本特征的相似度,并进行softmax操作
        similarity = (100.0 * (image_features / image_nor) @ (text_features / nor).T).softmax(dim=-1)
        probs = similarity
        prob = probs[0][0]
        score = score + prob
    # 计算平均得分
    score = score / tensor.shape[0]
    return score

# 定义一个自定义的从特征计算CLIP损失的类
class L_clip_from_feature(nn.Module):
    def __init__(self):
        super(L_clip_from_feature,self).__init__()
        # 冻结该损失类的所有参数,不进行训练
        for param in self.parameters(): 
            param.requires_grad = False
  
    def forward(self, x, text_features):
        # 计算图像特征与文本特征的CLIP得分
        k1 = get_clip_score_from_feature(x, text_features)
        return k1

# 加载另一个CLIP模型(RN101),用于计算重建损失
res_model, res_preprocess = load("RN101", device=device, download_root="./clip_model/")
# 冻结RN101模型的所有参数,不进行训练
for para in res_model.parameters():
    para.requires_grad = False

# 定义一个函数,用于计算L2损失
def l2_layers(pred_conv_features, input_conv_features, weight):
    # 将权重转换为与特征相同的数据类型
    weight = torch.tensor(weight).type(pred_conv_features[0].dtype)
    # 计算预测特征和输入特征之间的L2损失
    return weight @ torch.tensor([torch.square(x_conv - y_conv).mean() for x_conv, y_conv in
            zip(pred_conv_features, input_conv_features)], requires_grad=True) / len(weight)

# 定义一个函数,用于计算CLIP的MSE损失
def get_clip_score_MSE(pred, inp, weight):
    score = 0
    for i in range(pred.shape[0]):
        # 对预测图像进行预处理
        pred_img = img_resize(pred[i])
        pred_img = clip_normalizer(pred_img.reshape(1,3,224,224))
        # 使用RN101模型编码预测图像的特征
        pred_image_features = res_model.encode_image(pred_img)

        # 对输入图像进行预处理
        inp_img = img_resize(inp[i])
        inp_img = clip_normalizer(inp_img.reshape(1,3,224,224))
        # 使用RN101模型编码输入图像的特征
        inp_image_features = res_model.encode_image(inp_img)
        
        MSE_loss_per_img = 0
        for feature_index in range(len(weight)):
            # 计算每个特征的MSE损失
            MSE_loss_per_img = MSE_loss_per_img + weight[feature_index] * F.mse_loss(pred_image_features[1][feature_index].squeeze(0), inp_image_features[1][feature_index].squeeze(0))
        score = score + MSE_loss_per_img
    return score

# 定义一个自定义的CLIP MSE损失类
class L_clip_MSE(nn.Module):
    def __init__(self):
        super(L_clip_MSE,self).__init__()
        # 冻结该损失类的所有参数,不进行训练
        for param in self.parameters(): 
            param.requires_grad = False
        
    def forward(self, pred, inp, weight=[1.0,1.0,1.0,1.0,0.5]):
        # 计算预测图像和输入图像的CLIP MSE损失
        res = get_clip_score_MSE(pred, inp, weight)
        return res

# 定义一个自定义的四边际损失类
class four_margin_loss(nn.Module):
    def __init__(self, dis1=0.7, dis2=0.3):
        super(four_margin_loss, self).__init__()
        # 初始化两个MarginRankingLoss实例,分别使用不同的边际值
        self.margin_loss_L = nn.MarginRankingLoss(dis1)
        self.margin_loss_S = nn.MarginRankingLoss(dis2)
        # 初始化一个L_clip_from_feature损失实例
        self.clip_loss = L_clip_from_feature()
    
    def forward(self, tensor0, tensor3, labels, num, *tensor_mid):
        # 计算输入和参考之间的边际损失
        loss_inp_ref = self.margin_loss_L(tensor0, tensor3, labels)
        if num == 2:
            print(tensor0, tensor3)
            return loss_inp_ref
        elif num == 3:
            print(tensor0, tensor_mid, tensor3)
            # 计算输入和半监督样本1之间的边际损失
            loss_inp_semi1 = self.margin_loss_L(tensor0, tensor_mid[0], labels)
            # 计算半监督样本1和参考之间的边际损失
            loss_semi1_ref = self.margin_loss_S(tensor_mid[0], tensor3, labels)
            return loss_inp_ref + loss_inp_semi1 + loss_semi1_ref
        elif num == 4:
            print(tensor0, tensor_mid, tensor3)
            # 计算输入和半监督样本1之间的边际损失
            loss_inp_semi1 = self.margin_loss_L(tensor0, tensor_mid[0], labels)
            # 计算半监督样本1和半监督样本2之间的边际损失
            loss_semi1_semi2 = self.margin_loss_S(tensor_mid[0], tensor_mid[1], labels)
            # 计算半监督样本2和参考之间的边际损失
            loss_semi2_ref = self.margin_loss_S(tensor_mid[1], tensor3, labels)
            return loss_inp_ref + loss_inp_semi1 + loss_semi1_semi2 + loss_semi2_ref

5、小结

本文是CLIP-LIT这种无监督暗光增强学习方式的代码解读;该方案无需大量配对的训练数据;利用 CLIP 模型的语义信息,能够生成更符合人类视觉感知的增强图像;通过迭代提示学习,不断优化图像增强效果。


以上针对于CLIP-LIT的代码实现的部分讲解完毕,如果有不清楚的问题欢迎大家提出。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值