# 标注文件为txt格式,文件中每一行为一条标注记录,本脚本将其按比例划分为训练集和验证集
# 运行指令:python divide.py --path [your file path] --train_ratio 0.8
# 程序运行完以后,会在当前工作目录下生成train.txt和val.txt两个文件
import argparse
import numpy as np
def divide(lines, train_ratio):
    """Randomly split annotation records into a training and a validation file.

    The records are shuffled with ``np.random.permutation``; the first
    ``int(len(lines) * train_ratio)`` shuffled records are written to
    ``./train.txt`` and the remaining ones to ``./val.txt``. Both paths are
    relative to the current working directory and existing files are
    overwritten. The resulting sizes are printed to stdout.

    Args:
        lines: Sequence of record strings, one record per element (for
            example the result of ``file.readlines()``).
        train_ratio: Fraction of the records assigned to the training set;
            must lie in the closed interval [0, 1].

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1] (including NaN).
    """
    # Validate before touching the output files so that a bad ratio does not
    # truncate an existing split.
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(
            'train_ratio must be within [0, 1], got {!r}'.format(train_ratio))

    length = len(lines)
    shuffled_indices = np.random.permutation(length)
    train_size = int(length * train_ratio)
    train_list = shuffled_indices[:train_size]
    val_list = shuffled_indices[train_size:]

    def terminated(indices):
        # The last line of a file often lacks a trailing newline; once it is
        # shuffled into the middle of an output file it would be glued to the
        # following record, so every record is written newline-terminated.
        for i in indices:
            line = lines[i]
            yield line if line.endswith('\n') else line + '\n'

    # Context managers guarantee both files are closed even if a write fails.
    with open('./val.txt', 'w', encoding='utf-8') as fp_val, \
            open('./train.txt', 'w', encoding='utf-8') as fp_train:
        fp_val.writelines(terminated(val_list))
        fp_train.writelines(terminated(train_list))

    print('train images ', train_size)
    print('val images ', length - train_size)
if __name__ == '__main__':
    # Command-line entry point: read the annotation file given by --path and
    # split its lines into train/val files according to --train_ratio.
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, default='./file.txt',
                        help='annotation txt file, one record per line')
    parser.add_argument('--train_ratio', type=float, default=0.8,
                        help='fraction of records written to the training set')
    args = parser.parse_args()
    # Use a context manager so the input file is closed once it has been read.
    with open(args.path, 'r', encoding='utf-8') as fp:
        lines = fp.readlines()
    divide(lines, args.train_ratio)