Lora(Low-Rank Adaptation) 是一种高效的模型微调方法,尤其在自然语言处理和生成任务中得到了广泛应用。与传统的微调方法相比,Lora方法通过低秩适配器的方式,使得模型微调更为高效,且资源消耗更少。本文将介绍如何使用微调后的Lora参数,并为您提供一些实践技巧,帮助您在实际应用中更好地利用这一技术。
一、什么是Lora?
Lora是一种轻量级的模型适配技术,旨在减少微调大规模预训练模型时的计算和存储需求。 通过引入低秩矩阵的适配器,Lora能够在不改变原始预训练模型的基础上,通过学习适当的低秩矩阵来调整模型的参数。这样,Lora使得微调过程中只需要调整少量的参数,极大地降低了微调的成本。
二、为什么选择Lora?
使用Lora进行微调的主要优势包括:
- 计算效率:传统的微调方法需要调整模型的大量参数,而Lora只调整低秩适配器的参数,显著减少了计算开销。
- 存储节省:Lora微调后需要存储的参数较少,尤其在处理大规模模型时,显著降低了存储需求。
- 适应性强:Lora能够适应不同任务和数据集,尤其适合需要频繁更新的任务。
三、如何使用微调后的Lora参数?
3.1 LLaMA-Factory
直接把lora训练后的参数合并到基座模型中,得到新的大模型进行使用。可以使用LLaMA-Factory实现。
CUDA_VISIBLE_DEVICES=0 llamafactory-cli export \
--model_name_or_path "model/Qwen2.5-7B-Instruct/" \
--adapter_name_or_path "LLaMA-Factory-main/saves/Qwen2.5-7B-Instruct/lora/train_2024-12-13-09-33-24" \
--template qwen \
--finetuning_type lora \
--export_dir Qwen2.5-7B-Instruct-lora/ \
--export_size 2 \
--export_legacy_format False
3.2 PeftModel
使用PeftModel加载基座模型和lora适配器
from transformers import AutoModelForCausalLM,AutoTokenizer
from peft import PeftModel
base_model_name = "model/Qwen2.5-7B-Instruct/"
lora_model_path = "Qwen2.5-7B-Instruct/lora/train_2024-12-13-09-33-24"
device = "cuda"
base_model = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
model = PeftModel.from_pretrained(base_model,lora_model_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
input_text = "你是谁?"
inputs = tokenizer(input_text, return_tensors="pt").to(device)
output = model.generate(**inputs, max_length=100, num_return_sequences=1)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
3.3 vllm
使用vllm加载基座模型和lora适配器,并做批量推理
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
base_model_name = "model/Qwen2.5-7B-Instruct/"
lora_model_path = "Qwen2.5-7B-Instruct/lora/train_2024-12-13-09-33-24"
llm = LLM(model=base_model_name , enable_lora=True)
sampling_params = SamplingParams(
temperature=0.3,
max_tokens=256,
)
prompts = [
"query1",
"query2",
"query3"
]
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest("lora1", 1, lora_model_path )
)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")