Deep Learning: Building Liveness Detection with the CelebA_Spoof Dataset - Model Building and Training

In the previous post, Deep Learning: Building Liveness Detection with the CelebA_Spoof Dataset - Data Processing, we preprocessed the CelebA_Spoof dataset for 2D liveness detection. This post covers training a 2D liveness-detection model on that data, including the model architecture, the distributed-training setup, and a walkthrough of the key code.

1. Project Overview

Personally, I have always thought liveness detection is somewhat of an ill-posed problem: face recognition itself does not really care whether a face is real or fake. But as face recognition is deployed more and more widely, it has become an important tool beyond the technology itself, so its security must be strengthened, and liveness detection naturally finds its place (I was quite reluctant to work on it). How do you design a liveness-detection model? The field already has plenty of research and solid methods that are safe to use in real applications, but as a beginner it is best to start with simple steps.
Here I treat 2D liveness detection as a binary-classification problem (in practice, liveness detection is far more complex). Since the model must be ported to an embedded device, I use a small model as the classifier backbone. You could argue that for a small model you can just stack a few convolution layers yourself, but why would a hand-rolled network beat a well-tested open-source one? So borrowing wins: this project uses knowledge distillation, with ResNet18 as the teacher model and SqueezeNet1.0 as the student model, trained on the CelebA_Spoof dataset.

2. Environment Setup

2.1 Hardware Requirements

  • Multi-GPU server (2-8 GPUs recommended); a single GPU also works, since this is a binary-classification task on a moderately sized dataset
  • CUDA 11.0 or later

2.2 Software Dependencies

pip install torch torchvision  # the main dependencies

2.3 Data Preprocessing

See the previous post, Deep Learning: Building Liveness Detection with the CelebA_Spoof Dataset - Data Processing. You can either use the JSON files shipped with the dataset or rebuild the dataset yourself. Here I rebuilt it:

import os
from PIL import Image
from pathlib import Path
from PIL import ImageFile
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

ImageFile.LOAD_TRUNCATED_IMAGES = True
print_lock = threading.Lock()

def load_bbox(bbox_file_path):
    """Load the bounding-box file"""
    if not os.path.exists(bbox_file_path):
        raise FileNotFoundError(f"Bounding box file not found: {bbox_file_path}")
    with open(bbox_file_path, 'r') as file:
        return list(map(float, file.read().strip().split()))

def process_image(args):
    """Process a single image (thread-safe)"""
    img_path, output_path, thread_id, remaining = args
    try:
        with print_lock:
            print(f"[Thread-{thread_id}] Processing {Path(img_path).name} ({remaining} remaining)")
            
        bbox_file = os.path.join(os.path.dirname(img_path), f"{Path(img_path).stem}_BB.txt")
        bbox = load_bbox(bbox_file)
        img = Image.open(img_path)
        real_w, real_h = img.size
        
        # Map the bbox (annotated at 224x224 scale) back to the original resolution
        x1 = int(bbox[0] * (real_w / 224))
        y1 = int(bbox[1] * (real_h / 224))
        w1 = int(bbox[2] * (real_w / 224))
        h1 = int(bbox[3] * (real_h / 224))
        
        # Crop and save
        cropped_img = img.crop((x1, y1, x1 + w1, y1 + h1))
        cropped_img.save(output_path)
        
        #with print_lock:
        #    print(f"[Thread-{thread_id}] Finished {Path(img_path).name}")
        return True
    except Exception as e:
        with print_lock:
            print(f"[Thread-{thread_id}] Error processing {Path(img_path).name}: {str(e)}")
        return False

def process_directory(input_dir, output_dir, max_workers=8):
    """Process all images in a directory with a thread pool"""
    os.makedirs(output_dir, exist_ok=True)
    tasks = []
    
    # Collect all image files
    image_files = [f for f in os.listdir(input_dir) if f.endswith((".png", ".jpg", ".jpeg"))]
    total = len(image_files)
    
    for idx, img_file in enumerate(image_files):
        img_path = os.path.join(input_dir, img_file)
        output_file = Path(img_file).with_suffix(".png")
        output_path = os.path.join(output_dir, output_file)
        remaining = total - idx - 1
        tasks.append((img_path, output_path, idx % max_workers, remaining))
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_image, task) for task in tasks]
        for future in as_completed(futures):
            future.result()

def process_data(root_dir, output_base_dir):
    """Process the whole dataset directory"""
    for split in ["test", "train"]:
        split_dir = os.path.join(root_dir, split)
        if not os.path.exists(split_dir):
            print(f"Warning: {split_dir} does not exist")
            continue
            
        # Create the matching output directories
        output_split_dir = os.path.join(output_base_dir, split)
        os.makedirs(os.path.join(output_split_dir, "live"), exist_ok=True)
        os.makedirs(os.path.join(output_split_dir, "spoof"), exist_ok=True)
            
        for sub_dir in os.listdir(split_dir):
            sub_path = os.path.join(split_dir, sub_dir)
            if not os.path.isdir(sub_path):
                print(f"Skipping non-directory: {sub_path}")
                continue
                
            for label in ["live", "spoof"]:
                label_path = os.path.join(sub_path, label)
                if not os.path.exists(label_path):
                    print(f"Warning: {label_path} does not exist")
                    continue
                    
                # Log which directory is being processed
                print(f"Processing directory: {label_path}")
                image_files = [f for f in os.listdir(label_path) if f.endswith((".png", ".jpg", ".jpeg"))]
                print(f"Found {len(image_files)} images in {label_path}")
                
                # Write to the matching split/live or split/spoof directory
                output_dir = os.path.join(output_split_dir, label)
                process_directory(label_path, output_dir, 24)

if __name__ == "__main__":
    input_root = "/CelebA_Spoof/CelebA_Spoof/Data"
    output_root = "/CelebA_Spoof_train"
    process_data(input_root, output_root)

The script walks the original data directory; under each person-ID folder's "live" and "spoof" subfolders it crops every image to its face ROI using the per-image bounding-box file, then saves the crop under the corresponding "live" or "spoof" output folder. The details are worth working through yourself.
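Once the crops are written, a quick sanity check helps confirm the output layout before training. A minimal sketch, assuming the `split/live` and `split/spoof` layout produced by the script above (the output root path is just an example):

```python
import os

def count_images(root, splits=("train", "test"), labels=("live", "spoof"),
                 exts=(".png", ".jpg", ".jpeg")):
    """Count cropped images per split/label to verify the preprocessing output."""
    counts = {}
    for split in splits:
        for label in labels:
            d = os.path.join(root, split, label)
            if not os.path.isdir(d):
                continue  # skip splits/labels that were not produced
            counts[f"{split}/{label}"] = sum(
                1 for f in os.listdir(d) if f.lower().endswith(exts))
    return counts

# Example: print(count_images("/CelebA_Spoof_train"))
```

A large imbalance between the live and spoof counts here is also a hint that you may want class weighting or resampling later.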

3. Code Structure

Now to the code. The project layout:

live_spoof/
├── train.py                 # main training script
└── utils/
    ├── dist_utils.py        # distributed-training helpers
    ├── data_utils.py        # data-loading helpers
    ├── model_utils.py       # model-building helpers
    └── train_utils.py       # training logic

4. Core Implementation

4.1 Model Architecture (model_utils.py)

Taking the easy road, I go straight to knowledge distillation: ResNet18 as the teacher model and SqueezeNet1.0 as the student model. Simple and convenient, though training it straight like this will not necessarily give the best results; treat it as a baseline, and later you can add other features of your own.

import torch
import torch.nn as nn
from torchvision import models

class ModelFactory:
    @staticmethod
    def build_teacher():
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        model.fc = nn.Linear(model.fc.in_features, 2)
        return model
    
    @staticmethod
    def build_student():
        model = models.squeezenet1_0(weights=models.SqueezeNet1_0_Weights.DEFAULT)
        model.classifier[1] = nn.Conv2d(512, 2, kernel_size=(1,1))
        return model

    @staticmethod
    def wrap_model(model, device, gpu):
        """Wrap the model for distributed data-parallel training"""
        model = model.to(device)
        return torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

4.2 Distributed Training Initialization (dist_utils.py)

I currently use the nccl backend; if you know a better option, please let me know, I'd be grateful.

import os
import torch
import torch.distributed as dist
import datetime

def setup_distributed(gpu, args):
    """Initialize the distributed training environment"""
    try:
        dist.init_process_group(
            backend='nccl',
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=gpu,
            timeout=datetime.timedelta(seconds=600)
        )
        torch.cuda.set_device(gpu)
        device = torch.device(f'cuda:{gpu}')
        return device
    except RuntimeError as e:
        print(f"[Rank {gpu}] Distributed init failed: {str(e)}")
        raise

4.3 Data Loading (data_utils.py)

Point it at the preprocessed data directories.

import torch.distributed as dist
from torchvision import datasets
from torch.utils.data import DataLoader, DistributedSampler

class DataLoaderFactory:
    def __init__(self, args):
        self.args = args
        self.num_workers = 8
        
    def create_loaders(self, train_dir, test_dir, transform):
        """Create distributed data loaders"""
        train_set = datasets.ImageFolder(train_dir, transform=transform)
        test_set = datasets.ImageFolder(test_dir, transform=transform)
        
        # Each process must sample with its own rank; args.rank is the launch
        # rank and would be 0 in every worker, duplicating data across GPUs.
        rank = dist.get_rank()
        train_sampler = DistributedSampler(
            train_set, num_replicas=self.args.world_size, rank=rank)
        test_sampler = DistributedSampler(
            test_set, num_replicas=self.args.world_size, rank=rank, shuffle=False)
            
        train_loader = DataLoader(
            train_set, batch_size=self.args.batch_size,
            sampler=train_sampler, num_workers=self.num_workers, pin_memory=True)
            
        test_loader = DataLoader(
            test_set, batch_size=self.args.batch_size,
            sampler=test_sampler, num_workers=self.num_workers, pin_memory=True)
            
        return train_loader, test_loader, len(train_set), len(test_set)

4.4 Training Logic (train_utils.py)

Standard training logic; the loss weighting and distillation temperature use typical defaults, nothing special. If the results are poor later, these are the first hyperparameters to tune.

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# ... keep the original imports ...
from torchvision import transforms
import datetime

class Trainer:
    def __init__(self, student, teacher, device, gpu, args):
        """Initialize training components"""
        self.student = student
        self.teacher = teacher.eval()  # the teacher stays in eval mode
        self.device = device
        self.gpu = gpu
        self.args = args
        
        # Optimizer setup
        self.criterion = nn.CrossEntropyLoss().to(device)
        self.optimizer = optim.Adam(student.parameters(), lr=args.lr)
        self.scheduler = StepLR(self.optimizer, 
                               step_size=args.step_size,
                               gamma=args.gamma)
        self.best_acc = 0.0
        
        # Checkpoint directory
        self.model_dir = "models"
        os.makedirs(self.model_dir, exist_ok=True)

    def _compute_loss(self, student_out, teacher_out, targets):
        """Compute the distillation loss"""
        # KL divergence between temperature-softened distributions
        kl_loss = nn.KLDivLoss(reduction='batchmean')(
            torch.log_softmax(student_out/self.args.T, dim=1),
            torch.softmax(teacher_out/self.args.T, dim=1)
        ) * (self.args.T ** 2)
        
        # Cross-entropy on the hard labels
        ce_loss = self.criterion(student_out, targets)
        
        return self.args.alpha * kl_loss + (1 - self.args.alpha) * ce_loss

    def train_epoch(self, train_loader, epoch):
        """Run one training epoch"""
        self.student.train()
        total_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(self.device), target.to(self.device)
            
            self.optimizer.zero_grad()
            
            # Forward pass
            student_out = self.student(data)
            with torch.no_grad():
                teacher_out = self.teacher(data)
            
            # Compute loss
            loss = self._compute_loss(student_out, teacher_out, target)
            
            # Backward pass
            loss.backward()
            self.optimizer.step()
            
            # Accumulate metrics
            total_loss += loss.item()
            _, predicted = student_out.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            
            # Only the main process logs
            if self.gpu == 0 and batch_idx % 200 == 0:
                avg_loss = total_loss / (batch_idx + 1)
                acc = 100. * correct / total
                print(f'Epoch {epoch} Batch {batch_idx}/{len(train_loader)} '
                      f'Loss: {avg_loss:.4f} | Acc: {acc:.2f}%')
        
        self.scheduler.step()  # advance the LR schedule once per epoch
        return {
            'loss': total_loss / len(train_loader),
            'accuracy': 100. * correct / total
        }

    def validate(self, val_loader):
        """Run validation"""
        self.student.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)
                outputs = self.student(data)
                loss = self.criterion(outputs, target)
                
                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()
        
        accuracy = 100. * correct / total
        avg_loss = total_loss / len(val_loader)
        
        if self.gpu == 0:
            print(f'Validation Loss: {avg_loss:.4f} | Acc: {accuracy:.2f}%')
        
        return accuracy

    def save_checkpoint(self, epoch, acc, is_best=False):
        """Save a checkpoint (main process only)"""
        if self.gpu != 0:  # only the main process writes to disk
            return
        
        state = {
            'epoch': epoch,
            'accuracy': acc,
            'student_state_dict': self.student.module.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict()
        }
        
        filename = f"checkpoint_epoch_{epoch}.pth" if not is_best else "best_model.pth"
        save_path = os.path.join(self.model_dir, filename)
        
        torch.save(state, save_path)
        print(f"[{datetime.datetime.now()}] 模型已保存至 {save_path}")

def setup_transform(augment=False):
    """Build the preprocessing pipeline, optionally with augmentation"""
    transform = [
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    ]
    
    if augment:  # training-set augmentation
        transform.insert(1, transforms.RandomHorizontalFlip())
        transform.insert(2, transforms.ColorJitter(0.1, 0.1, 0.1))
    
    return transforms.Compose(transform)

4.5 Main Training Script (train.py)

The main entry point for training.

import os
import torch
import argparse
import socket
import datetime
import numpy as np
import torch.multiprocessing as mp
from utils.dist_utils import setup_distributed
from utils.data_utils import DataLoaderFactory
from utils.model_utils import ModelFactory
from utils.train_utils import Trainer, setup_transform

def main_worker(gpu, ngpus_per_node, args):
    # Initialize the distributed environment
    if gpu == 0:
        print(f"\n=== Initialization ===")
        print(f"Main process: GPU{gpu} | Total GPUs: {args.world_size}")
    
    try:
        device = setup_distributed(gpu, args)
        if gpu == 0:
            print(f"[GPU{gpu}] Distributed setup complete | Device: {device}")
    except Exception as e:
        print(f"[GPU{gpu}] Initialization failed: {str(e)}")
        return

    # Data loading
    if gpu == 0:
        print(f"\n=== Data Loading ===")
        print(f"Train data dir: {args.train_data_dir}")
        print(f"Test data dir: {args.test_data_dir}")
    
    loader_factory = DataLoaderFactory(args)
    train_loader, val_loader, train_size, val_size = loader_factory.create_loaders(
        args.train_data_dir, 
        args.test_data_dir,
        setup_transform()
    )
    
    if gpu == 0:
        print(f"[GPU{gpu}] Data loaded | Train set: {train_size:,} | Test set: {val_size:,}")
        print(f"Batch size: {args.batch_size} | Workers: {loader_factory.num_workers}")

    # Model initialization
    if gpu == 0:
        print(f"\n=== Model Initialization ===")
    
    teacher = ModelFactory.wrap_model(ModelFactory.build_teacher(), device, gpu).eval()
    student = ModelFactory.wrap_model(ModelFactory.build_student(), device, gpu)
    
    if gpu == 0:
        params = sum(p.numel() for p in student.parameters()) / 1e6
        print(f"[GPU{gpu}] Models ready | Student parameters: {params:.2f}M")
        print(f"Teacher model: {type(teacher.module).__name__}")
        print(f"Student model: {type(student.module).__name__}")

    # Training setup
    if gpu == 0:
        print(f"\n=== Training Config ===")
        print(f"LR: {args.lr} | Distillation temperature: {args.T}")
        print(f"Optimizer: Adam | Scheduler: StepLR (decay by {args.gamma}x every {args.step_size} epochs)")
    
    trainer = Trainer(
        student=student, 
        teacher=teacher, 
        device=device, 
        gpu=gpu, 
        args=args
    )

    # Main training loop
    if gpu == 0:
        print(f"\n=== Training Start ===")
        print(f"Total epochs: {args.epochs} | Save interval: every {args.save_interval} epochs")
    
    for epoch in range(1, args.epochs + 1):
        train_loader.sampler.set_epoch(epoch)  # reshuffle shards across ranks each epoch
        if gpu == 0:
            print(f"\nEpoch {epoch}/{args.epochs} - LR: {trainer.scheduler.get_last_lr()[0]:.6f}")
        # Training phase
        train_metrics = trainer.train_epoch(train_loader, epoch)
        train_loss, train_acc = train_metrics['loss'], train_metrics['accuracy']
        # Validation phase
        val_acc = trainer.validate(val_loader)
        if gpu == 0:
            print(f"[Epoch {epoch}] Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")
        
        # Save the best model so far, plus periodic checkpoints
        if gpu == 0:
            if val_acc > trainer.best_acc:
                trainer.best_acc = val_acc
                trainer.save_checkpoint(epoch, val_acc, is_best=True)
            if epoch % args.save_interval == 0:
                trainer.save_checkpoint(epoch, val_acc)

if __name__ == '__main__':
    # Parse arguments
    parser = argparse.ArgumentParser(description='Knowledge-distillation training for 2D liveness detection')
    # ... [keep the original argument definitions] ...
    parser.add_argument('--seed', type=int, default=42, 
                      help='Global random seed (default: 42)')    
    parser.add_argument('--train-data-dir', type=str, required=True, 
                      help='Path to the training dataset directory')
    parser.add_argument('--test-data-dir', type=str, required=True,
                      help='Path to the test dataset directory')
    parser.add_argument('--world-size', default=-1, type=int, help='Number of GPUs to use')
    parser.add_argument('--dist-url', default='env://', type=str, help='URL used to set up distributed training')
    parser.add_argument('--rank', default=0, type=int, help='Rank of the current process')
    parser.add_argument('--lr', default=0.001, type=float, help='Initial learning rate')
    parser.add_argument('--step-size', default=10, type=int, help='Period of learning rate decay')
    parser.add_argument('--gamma', default=0.1, type=float, help='Multiplicative factor of learning rate decay')
    parser.add_argument('--alpha', default=0.5, type=float, help='Weight for knowledge distillation loss')
    parser.add_argument('--T', default=4, type=float, help='Temperature for knowledge distillation')
    parser.add_argument('--epochs', default=50, type=int, help='Total number of epochs to train')
    parser.add_argument('--log-interval', default=5, type=int, help='Interval between logging epochs')
    parser.add_argument('--batch-size', default=32, type=int, help='Batch size for training and validation')
    parser.add_argument('--save-interval', default=10, type=int, help='Interval between saving checkpoints')
    args = parser.parse_args()
    # Auto-detect world_size
    if args.world_size == -1:
        args.world_size = torch.cuda.device_count()
        print(f"Detected {args.world_size} available GPUs")
    # Dynamic port allocation
    def find_free_port():
        with socket.socket() as s:
            s.bind(('', 0))
            return s.getsockname()[1]
    
    os.environ.setdefault('MASTER_ADDR', 'localhost')  # fall back if the launcher did not set it
    os.environ['MASTER_PORT'] = str(find_free_port())
    args.dist_url = f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}'    
    # Launch distributed training
    mp.spawn(main_worker, nprocs=args.world_size, args=(args.world_size, args))

5. Running the Training

5.1 Launch Script

#!/bin/bash

# Log file path
LOG_DIR="/logs"
mkdir -p $LOG_DIR
LOG_FILE="$LOG_DIR/train_$(date +%Y%m%d_%H%M%S).log"

# Environment variables
export WORLD_SIZE=2
export MASTER_ADDR=localhost
export MASTER_PORT=29500

# Launch distributed training and capture the log
{
    echo "========== Training started =========="
    echo "Time: $(date)"
    echo "GPUs: ${CUDA_VISIBLE_DEVICES:-not set, using all available GPUs}"
    echo "Parameters:"
    echo "  WORLD_SIZE=$WORLD_SIZE"
    echo "  MASTER_ADDR=$MASTER_ADDR"
    echo "  MASTER_PORT=$MASTER_PORT"
    
    CUDA_VISIBLE_DEVICES=0,1 python train.py \
        --train-data-dir /CelebA_Spoof_train/train \
        --test-data-dir /CelebA_Spoof_train/test \
        --world-size $WORLD_SIZE \
        --dist-url env:// \
        --rank 0 \
        --lr 0.001 \
        --step-size 10 \
        --gamma 0.1 \
        --alpha 0.5 \
        --T 4 \
        --epochs 50 \
        --log-interval 5 \
        --batch-size 256 \
        --save-interval 10
    status=$?
    echo "========== Training finished =========="
    echo "Time: $(date)"
    echo "Exit status: $status"
} 2>&1 | tee -a $LOG_FILE

echo "Training log saved to: $LOG_FILE"

5.2 Key Parameters

These are reasonable defaults, not necessarily the best values.

Parameter     Description                   Default
--lr          initial learning rate         0.001
--T           distillation temperature      4.0
--alpha       distillation-loss weight      0.5
--step-size   LR decay period (epochs)      10
--gamma       LR decay factor               0.1
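To see what --T does: dividing the logits by the temperature before the softmax flattens the teacher's output distribution, so the student also learns from the relative score of the non-target class instead of an almost one-hot target. A small illustration (the logit values are made up):

```python
import torch

def soften(logits, T):
    # Higher T -> flatter distribution -> more "dark knowledge" per example
    return torch.softmax(logits / T, dim=-1)

logits = torch.tensor([3.0, 0.5])   # hypothetical raw live/spoof scores
print(soften(logits, T=1.0))  # sharp, close to one-hot
print(soften(logits, T=4.0))  # softened, the second class keeps visible mass
```

This is also why the KL term in `_compute_loss` is scaled by `T ** 2`: the softened gradients shrink with temperature, and the factor keeps the two loss terms on a comparable scale.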

6. Results

During training the following metrics are logged:

  • Training loss / accuracy
  • Validation loss / accuracy
  • Learning-rate changes

Sample output:

Epoch 27/50 - LR: 0.001000
Epoch 27 Batch 0/966 Loss: 0.2262 | Acc: 97.66%
Epoch 27 Batch 200/966 Loss: 0.2267 | Acc: 99.10%
Epoch 27 Batch 400/966 Loss: 0.2267 | Acc: 99.06%
Epoch 27 Batch 600/966 Loss: 0.2266 | Acc: 99.05%
Epoch 27 Batch 800/966 Loss: 0.2265 | Acc: 99.06%
Validation Loss: 0.3762 | Acc: 92.83%
[Epoch 27] Train Loss: 0.2265 | Train Acc: 99.06% | Val Acc: 92.83%
[2025-04-28 16:41:47.486721] Model saved to models/checkpoint_epoch_27.pth

The best model is saved automatically as models/best_model.pth.

7. Summary

This post walked step by step through a complete knowledge-distillation training system for liveness detection, with the following features:

  1. The teacher model provides strong feature representations
  2. The student model stays lightweight while retaining good accuracy
  3. Distributed training substantially speeds up convergence
  4. Automatic logging and checkpointing simplify experiment management

Tuning the distillation temperature and loss weight can improve the model further, and the framework extends easily to other visual classification tasks. If you have experience with tuning these training parameters, or with adding other features to the loss computation, please share it; I'd be grateful.
