DeeplabV3+ 训练自己的数据集。pytorch
代码目录数据集构建本实验以landsat8遥感影像海岛对象分割为例。在原始图像(a)的基础上,使用labelme,ps或别的工具,对海岛区域进行标注。标注结果如下图(b)所示:使用代码将标注的海岛区域像素质设置为1,背景区域像素质设置为0,保存为png格式的单通道图像,构建得到label数据,如下图©所示:...
·
一. 代码目录

- dataloaders 存放数据集读取代码
- datasets 存放训练数据
- modeling 存放DeeplabV3+模型文件
- run_lab 保存每次训练的结果参数
- test_result 输出测试结果图
- utils 存放一些工具函数
- train_model.py 为训练主函数
- inference.py 为测试函数
二. 数据集构建
-
本实验以landsat8遥感影像海岛对象分割为例。在原始图像(a)的基础上,使用labelme,ps等工具,对海岛区域进行标注。标注结果如下图(b)所示:
-
使用代码将标注的海岛区域像素质设置为1,背景区域像素质设置为0,保存为png格式的单通道图像,构建得到label数据,如下图©所示:
-
将训练数据集放入train文件夹中,对应的label数据集放入labels文件夹。执行make_train_val_txt_file.py文件,随机选取80%数据作为训练数据,生成train.txt文件;20%数据作为验证数据,生成val.txt文件。
三. 网络训练
- 在 train_model.py 文件中的对网络训练的部分超参数进行配置
parser.add_argument('--network', type=str, default='Deeplab_resnet101',
choices=['Deeplab_resnet101', 'Deeplab_xception', 'Deeplab_drn',
'Deeplab_mobilenet', 'Deeplab_resnet50'],
help='选择用于特征提取的主干网络类型')
parser.add_argument('--loss-type', type=str, default='ce',
choices=['ce', 'focal'],
help='选择损失函数类型,ce为Cross-Entropy交叉熵损失函数。focal_loss用于类别不均衡数据集。')
parser.add_argument('--nclass', type=int, default=2, metavar='N',
help='定义网络是几分类模型,默认为2分类')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
help='定义网络训练epochs次数,默认为200次')
arser.add_argument('--batch-size', type=int, default=5,
metavar='N', help='定义训练阶段batch-size大小')
parser.add_argument('--lr', type=float, default=0.007, metavar='LR',
help='定义学习率大小默认为0.007')
parser.add_argument('--lr-scheduler', type=str, default='poly',
choices=['poly', 'step', 'cos'],
help='定义学习率优化策略。默认选择为:poly优化策略')
- 本实验将图像分割为海岛区域和背景区域,–nclass默认为2分类
- 选择poly学习率优化策略(–lr-scheduler),学习率大小随着训练epoch次数的增加根据如下公式进行变化:
- 其余超参数设置见代码所示
- 训练过程如下图所示:
learning-rate为当前epoch学习率大小。previous-best为该training中最高mIoU得分,默认保存最高mIoU得分epoch的网络权重参数。每个Validation显示当前epoch的验证集Acc、Acc_class、mIoU、fwIoU得分。 - 训练结果保存在run_lab目录下:
根据超参数–network中的主干网络设置,生成相应的目录文件(Deeplab_resnet101),代码每次training生成对应的experiment_x目录文件,Deeplab_resnet101_model_best.pth为所有experiment中最高mIoU得分的网络权重参数。以第10次训练的experiment_10为例,experiment_10目录中包含:
- best_pred:验证集最高mIoU得分。
- checkpoint.pth:最后一个epoch的网络权重参数
- events.out.tfevents.xx:tensorboard可视化文件
- parameters:当前网络超参数设置
- tensorboard可视化结果:
四. 结果测试
- 将测试图片放入test_result/test目录下,执行inference.py文件,分割结果将会保存在test_result/test_output目录中
- landsat8遥感影像海岛分割结果如图所示。(A)为海岛原图,(B)为Deeplab网络分割结果,(C)为单独使用条件随机场的优化结果:

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)