AUTOVC 代码解析 —— main.py
简介
本项目一个基于 AUTOVC 模型的语音转换项目,它是使用 PyTorch 实现的(项目地址)。
AUTOVC 遵循自动编码器框架,只对自动编码器损耗进行训练,但它引入了精心调整的降维和时间下采样来约束信息流,这个简单的方案带来了显著的性能提高。(详情请参阅 AUTOVC 的详细介绍)。
由于 AUTOVC 项目较大,代码较多。为了方便学习与整理,将按照工程文件的结构依次介绍。
本文将介绍项目中的 main.py 文件:网络训练主函数
函数解析
str2bool
该函数的作用是: 将字符串形式的 ‘True’ 与 ‘False’ 转换为布尔型的 True 与 False 。
输入参数:
v : 需要改变类型的字符
输出参数:
v.lower() in ('true') : True 或 False
代码详解:
def str2bool(v):
# 将字符串中的字符全部小写
# 判断转换后的字符是否为字符串 'true' 的子串,并返回
return v.lower() in ('true')
main
该函数的作用是: 加载数据,配置网络模型,训练模型
输入参数:
config : 网路模型配置
输出参数: 无
代码详解:
def main(config):
# 快速训练
# 如果网络的输入数据维度或类型上变化不大,设置 torch.backends.cudnn.benchmark = true 可以增加运行效率
cudnn.benchmark = True
# 数据迭代器
vcc_loader = get_loader(config.data_dir, config.batch_size, config.len_crop)
# 构建网络模型
solver = Solver(vcc_loader, config)
# 训练网络模型
solver.train()
if __ name __ == ‘__ main __’:
该函数的作用是: 处理参数,组合成网络配置参数,调用主函数训练网络
输入参数: 无
输出参数: 无
代码详解:
if __name__ == '__main__':
# 创建解析器
parser = argparse.ArgumentParser()
#模型配置
parser.add_argument('--lambda_cd', type=float, default=1, help='隐藏编码损失的权重')
# 内容编码长度
parser.add_argument('--dim_neck', type=int, default=16)
# 说话人编码长度
parser.add_argument('--dim_emb', type=int, default=256)
# 后置网络输出长度
parser.add_argument('--dim_pre', type=int, default=512)
# 时间采样系数
parser.add_argument('--freq', type=int, default=16)
# 训练配置
# 说话文件--梅尔数据文件目录
parser.add_argument('--data_dir', type=str, default='./spmel')
# 批大小
parser.add_argument('--batch_size', type=int, default=2, help='mini-batch size')
# 训练次数
parser.add_argument('--num_iters', type=int, default=1000000, help='number of total iterations')
# 截取长度
parser.add_argument('--len_crop', type=int, default=128, help='dataloader output sequence length')
# 杂项
# 日志打印间隔
parser.add_argument('--log_step', type=int, default=10)
# 将上述配置组装为网络模型配置
config = parser.parse_args()
# 打印网络模型配置
print(config)
# 调用主函数,准备训练
main(config)