[CLIP-VIT-L + Qwen] 多模态大模型源码阅读 - 模型训练篇

参考repo:WatchTower-Liu/VLM-learning; url: VLLM-BASE
前情提要
有关多模态大模型架构中的语言模型部分(MQwen.py)的代码请看(多模态大模型源码阅读 - 1、 多模态大模型源码阅读 - 2, 多模态大模型源码阅读 - 3,多模态大模型源码阅读 - 4)
多模态大模型架构中的视觉模型(visual/CLIP-VIT.py)部分请看多模态大模型源码阅读 - 5
多模态大模型架构中的trainer(trainer.py)部分请看多模态大模型源码阅读 - 6
多模态大模型架构中的MultiModal融合部分(MultiModal.py)部分请看多模态大模型源码阅读 - MultiModal篇。
多模态大模型架构中的Dataset部分请看多模态大模型源码阅读 - Dataset篇
观前提醒,本文中介绍的多模态模型架构来源于github项目WatchTower-Liu/VLM-learning,对Qwen模型的前向传播代码进行重写,并通过中间投影层将视觉特征与文本映射到同一向量空间。投影层原理参考LLAVA
本节介绍的是模型训练部分。将视觉模型的参数冻结,并采用LoRA对语言模型进行微调。训练参数包括语言模型中LoRA的参数和中间投影层参数。
其中投影层参数为初始化参数,为了平衡模型参数优化速度,这里为映射层设定了比Lora部分更大的学习率。
源码阅读
完整代码
import os
import json
import torch
from typing import Optional
from functools import partial
from trainer import MultiModalTrainer
from model.model import MMultiModal, LanguageConfig, VisualConfig, MultiModalConfig
from dataset.image_caption_dataset import ImageCaptionDataset, data_collate
import transformers
from transformers import HfArgumentParser, AutoTokenizer
from dataclasses import dataclass, field
from qwen.modeling_qwen import QWenLMHeadModel
from accelerate import Accelerator
# from peft import LoraConfig, TaskType, get_peft_model, PeftModel
# from einops import rearrange
@dataclass
class FinetuneArguments:
lora_rank: int = field(default=8)
lora_dropout: float = field(default=0.1)
previous_lora_weights: Optional[str] = field(default=None)
target_modules: str = field(default="W_pack")
image_map: str = field(default="data/image_map_b.json", metadata={
"help": "图像文件与索引ID"})
captions_file: str = field(default="data/captions_b.json", metadata={
"help": "ID与caption的对应"})
@dataclass
class TrainingArguments(transformers.TrainingArguments):
feature_proj_lr: Optional[float] = None
def train():
finetune_args, training_args = HfArgumentParser(
(FinetuneArguments, TrainingArguments)
).parse_args_into_dataclasses()
base_language_model = "Qwen/Qwen-7B-Chat"
# base_language_model = "openbmb/MiniCPM-2B-history"
# base_value_model = "openai/clip-vit-large-patch14"
base_value_model = "google/siglip-so400m-patch14-384"
tokenizer = AutoTokenizer.from_pretrained(base_language_model, trust_remote_code=True)
replace_token_id = tokenizer.convert_tokens_to_ids("<|extra_0|>")
# Check file paths
if not os.path.exists(finetune_args.image_map):
raise FileNotFoundError(f"Image map file not found: {
finetune_args.image_map}")
if not os.path.exists(finetune_args.captions_file):
raise FileNotFoundError(f"Captions file not found: {
finetune_args.captions_file}")
# Load and check file contents
with open(finetune_args