Salience-DETR 项目复现
说明:官方文档很详细,但是依赖版本更新导致配置问题,可能会导致无法运行
本文包含完整的复现流程,涵盖训练模型和评估和测试
文章目录
0. 如何使用
conda activate salience_detr
cd Salience-DETR/
CUDA_VISIBLE_DEVICES=0,1 accelerate launch test.py --coco-path datasets/ --model-config configs/salience_detr/salience_detr_swin_l_800_1333.py --checkpoint checkpoints/salience_detr_swin_l_800_1333_coco_1x.pth --result result.json --show-dir visualization/
- 指定可见的GPU设备编号为0和1,限制程序仅使用这两块GPU
- accelerate launch是用于简化多GPU/TPU训练的启动命令
- 指定COCO格式数据集的存储路径
- 使用Swin-Large作为backbone
- 指定预训练权重文件路径
- 测试结果将以JSON格式保存(包含检测框坐标、类别、置信度等信息)
- 测试过程中生成的检测结果可视化图片将保存在该目录下
运行结果如下
1.安装步骤
克隆本仓库:
git clone https://github.com/xiuqhou/Salience-DETR.git
cd Salience-DETR/
创建并激活conda环境:
conda create -n salience_detr python=3.8
conda activate salience_detr
根据官方步骤 https://pytorch.org/get-started/locally/ 安装pytorch。本代码要求 python>=3.8, torch>=1.11.0, torchvision>=0.12.0
。
【注意:这里不要使用官方的配置,torchvision == 0.12.0 有问题】
【具体原因和解释见:https://github.com/xiuqhou/Salience-DETR/issues/64】
# conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
安装其他依赖:
conda install --file requirements.txt -c conda-forge
您不需要手动编译CUDA算子,代码第一次运行时会自动编译并加载。
2.准备数据集
请按照如下格式下载 COCO 2017 数据集或准备您自己的数据集,并将他们放在 data/
目录下。您可以使用 tools/visualize_datasets.py
来可视化数据集以验证其正确性。
这个数据放在什么位置都行,记住路径即可
coco/
├── train2017/
├── val2017/
└── annotations/
├── instances_train2017.json
└── instances_val2017.json
可视化例子
conda activate salience_detr
python tools/visualize_datasets.py \
--coco-img datasets/val2017 \
--coco-ann datasets/annotations/instances_val2017.json \
--show-dir visualize_dataset/
目录结构如下
(base) zhanyong@532lab:~/Salience-DETR/datasets$ ll
总计 5780
drwxrwxr-x 7 zhanyong zhanyong 4096 4月 10 13:54 ./
drwxrwxr-x 16 zhanyong zhanyong 4096 4月 10 16:18 ../
drwxr-xr-x 2 root root 4096 4月 10 13:45 annotations/
-rw-rw-r-- 1 zhanyong zhanyong 6082 4月 10 13:25 coco.py
drwxrwxr-x 2 zhanyong zhanyong 4096 4月 10 13:54 __pycache__/
drwxr-xr-x 2 root root 1437696 4月 10 13:46 test2017/
drwxr-xr-x 2 root root 4272128 4月 10 13:50 train2017/
drwxr-xr-x 2 root root 180224 4月 10 13:50 val2017/
3.训练模型
我们使用 accelerate
包来原生处理多GPU训练,您只需要使用 CUDA_VISIBLE_DEVICES
来指定要用于训练的GPU/GPUs。如果未指定,脚本会自动使用机器上所有可用的GPU来训练。
CUDA_VISIBLE_DEVICES=0 accelerate launch main.py # 使用1个GPU进行训练
CUDA_VISIBLE_DEVICES=0,1 accelerate launch main.py # 使用2个GPU进行训练
训练之前请调整 configs/train_config.py
中的参数。
由于我之前是把数据放到了datasets下,所以这里修改成 datasets即可
4.评估和测试
为了使用单个或多个GPU来评估模型,请指定 CUDA_VISIBLE_DEVICES
、dataset
、model
、checkpoint
等参数。
CUDA_VISIBLE_DEVICES=<gpu_ids> accelerate launch test.py --coco-path /path/to/coco --model-config /path/to/model.py --checkpoint /path/to/checkpoint.pth
以下是可选参数,更多参数请查看 test.py 。
--show-dir
: 指定用于保存可视化结果的文件夹路径。--result
: 指定用于保存检测结果的文件路径,必须以.json
结尾。
模型评估的例子
例如,使用8张GPU来在 coco
上评估 salience_detr_resnet50_800_1333
模型,并将检测结果保存至 result.json
文件,并将检测结果的可视化保存至 visualization/
文件夹下,请运行以下命令:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch test.py
--coco-path data/coco \
--model-config configs/salience_detr/salience_detr_resnet50_800_1333.py \
--checkpoint checkpoints/salience_detr_resnet50_800_1333/train/2024-03-22-21_29_56/best_ap.pth \
--result result.json \
--show-dir visualization/
安装之前配置的内容,修改成即可
conda activate salience_detr
CUDA_VISIBLE_DEVICES=0,1 accelerate launch test.py --coco-path datasets/ --model-config configs/salience_detr/salience_detr_swin_l_800_1333.py --checkpoint checkpoints/salience_detr_swin_l_800_1333_coco_1x.pth --result result.json --show-dir visualization/
5.Bug
1. FAILED : ms_deform_attn_cdua.o / gcc: fatal error: cannot execute ‘cc1plus’
未能解决相关问题,好在不影响运行
2. cannot import name ‘MLP’ from ‘torchvision.ops’
如果使用 conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch应该不会出现,可忽略下述内容
在Torchvision 0.12 docs 版本中已经不存在 MLP 需要提升版本
- Torchvision 0.12 docs – No MLP
- Torchvision 0.13 docs – Contains MLP
3. PyTorch 与 torchvision 版本不兼容
Couldn't load custom C++ ops... Please check your PyTorch version with torch.__version__ and your torchvision version with torchvision.__version__...
并且还出现了
undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
这往往意味着二进制接口(ABI)或符号不匹配,通常由不兼容的 PyTorch / torchvision 版本引起。
以下是官方常见版本对应(简化示例),可在 PyTorch 官网 或 torchvision GitHub 查到更多信息:
torch | torchvision | Python |
---|---|---|
main / nightly | main / nightly | >=3.9 , <=3.12 |
2.5 | 0.20 | >=3.9 , <=3.12 |
2.4 | 0.19 | >=3.8 , <=3.12 |
2.3 | 0.18 | >=3.8 , <=3.12 |
2.2 | 0.17 | >=3.8 , <=3.11 |
2.1 | 0.16 | >=3.8 , <=3.11 |
2.0 | 0.15 | >=3.8 , <=3.11 |
torch | torchvision | Python |
---|---|---|
1.13 | 0.14 | >=3.7.2 , <=3.10 |
1.12 | 0.13 | >=3.7 , <=3.10 |
1.11 | 0.12 | >=3.7 , <=3.10 |
1.10 | 0.11 | >=3.6 , <=3.9 |
1.9 | 0.10 | >=3.6 , <=3.9 |
1.8 | 0.9 | >=3.6 , <=3.9 |
1.7 | 0.8 | >=3.6 , <=3.9 |
1.6 | 0.7 | >=3.6 , <=3.8 |
1.5 | 0.6 | >=3.5 , <=3.8 |
1.4 | 0.5 | ==2.7 , >=3.5 , <=3.8 |
1.3 | 0.4.2 / 0.4.3 | ==2.7 , >=3.5 , <=3.7 |
1.2 | 0.4.1 | ==2.7 , >=3.5 , <=3.7 |
1.1 | 0.3 | ==2.7 , >=3.5 , <=3.7 |
<=1.0 | 0.2 | ==2.7 , >=3.5 , <=3.7 |
因此推荐conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch