针对夸夸闲聊数据集,利用UniLM模型进行模型训练及测试,更深入地了解预训练语言模型的使用方法,完成一个生成式闲聊机器人任务。
项目主要结构如下:
- data 存放数据的文件夹
- dirty_word.txt 敏感词数据
- douban_kuakua_qa.txt 原始语料 【数据量:大概20M的样子】==》用于增量训练
- sample.json 处理后的语料样例
- kuakua_robot_model 已训练好的模型路径
- config.json
- pytorch_model.bin
- vocab.txt
- pretrain_model UniLm预训练文件路径
- config.json
- pytorch_model.bin
- vocab.txt
- chatbot.py 模型推理文件
- configuration_unilm.py UniLm配置文件
- data_helper.py 数据预处理文件
- data_set.py 数据类文件
- modeling_unilm.py UniLm模型文件
- train.py 模型训练文件
- dirty_recognize.py 敏感词检测文件
增量训练的数据样例:
1 Q: 要去打球赛了求表扬
2 A: 真棒好好打乒乓球!
3 Q: 要去打球赛了求表扬
4 A: 是篮球哈哈哈
5 Q: 要去打球赛了求表扬
6 A: 篮板王就是你!
7 Q: 要去打球赛了求表扬
8 A: 加油别把鞋踢脏喽
9 Q: 要去打球赛了求表扬
10 A: 多买点儿币!
11 Q: 要去打球赛了求表扬
12 A: 已经脏了
13 Q: 要去打球赛了求表扬
14 A: 好滴
15 Q: 要去打球赛了求表扬
16 A: 这个配色是是真心不太合我的胃口,还有为什么白鞋要配黑袜子
17 Q: 要去打球赛了求表扬
18 A: 这不是表扬组吗hhh你咋来拆台
19 Q: 要去打球赛了求表扬
20 A: 我不是,我没有,别瞎说哈
21 Q: 要去打球赛了求表扬
22 A: 全场最帅(・ัω・ั),卡胃踩脚拇指戳肋骨无毒神掌天下无敌,然后需要代打嘛
23 Q: 要去打球赛了求表扬
24 A: 你走!
25 Q: 要去打球赛了求表扬
26 A: 8要!
27 Q: 要去打球赛了求表扬
28 A: 我不,我还想问问什么鞋码,多高多重,打什么位置的
注意:由于GitHub不方便放模型文件,因此data文件中douban_kuakua_qa.txt文件、kuakua_robot_model文件夹和pretrain_model文件夹中的模型bin文件,请从百度云盘中下载。【bert模型大小:400MB,用于增量训练的模型,应该是来自https://huggingface.co/bert-base-chinese/tree/main下载的原始bert文件】
文件名称 | 下载地址 | 提取码 |
---|---|---|
pretrain_model | 百度云 | 7h4a |
kuakua_robot_model | 百度云 | j954 |
data | 百度云 | 3sz3 |
由于敏感词表中包含大量敏感词,导致百度云的data链接会失效,因此将敏感词之间放到项目的data目录下。
环境配置
模型训练或推理所需环境,请参考requirements.txt文件。
数据处理
数据预处理需要运行data_helper.py文件,会在data文件夹中生成训练集和测试集文件。
命令如下:
python3 data_helper.py
注意:如果需要修改数据生成路径或名称,请修改data_helper.py文件147-150行,自行定义。
模型训练
模型训练需要运行train.py文件,会自动生成output_dir文件夹,存放每个epoch保存的模型文件。
命令如下:
python3 train.py --device 0 \
--data_dir "data/" \
--src_file "train.json" \
--model_name_or_path "pretrain_model/" \
--max_seq_length 256 \
--train_batch_size 16 \
--num_train_epochs 10
注意:当服务器资源不同或读者更换数据等时,可以在模型训练时修改响应参数,详细参数说明见代码或阅读书3.5.4小节。
模型训练示例如下:
模型推理
模型训练需要运行chatbot.py文件,可以采用项目中以提供的模型,也可以采用自己训练后的模型。
命令如下:
python3 chatbot.py --device 0 --topk 3 --topp 0.95 --max_len 32
注意:如果修改模型路径,请修改–model_path参数。
USER:工作不开心
CHAT-BOT:我也是,不过现在开始努力找工作啦!一起加油鸭~
USER:工作不开心
CHAT-BOT:那就祝福你,能找到喜欢的工作吧。
USER:工作不开心
CHAT-BOT:你也要加油鸭!
USER:工作不开心
CHAT-BOT:我也是,不过总有一天就会好起来的
样例2
USER:我分手了
CHAT-BOT:分手快乐,你会遇到更好的。
USER:我分手了
CHAT-BOT:旧的不去新的不来,你会遇见更好的
USER:我分手了
CHAT-BOT:分的好,下一个更乖
样例3
USER:好想去吃火锅,但是我在减肥
CHAT-BOT:你好棒!一定会瘦下来的!
USER:好像去吃火锅,但是我在减肥
CHAT-BOT:吃火锅是为了能更好的减肥,你真棒!
我自己没有gpu,训练了一天,运行的流程如下:
D:\book\ChatGPTBook``-``main\UniLMProj 的目录``2023``-``09``-``25` `17``:``54` `<``DIR``> .``2023``-``09``-``25` `17``:``22` `<``DIR``> ..``2023``-``09``-``25` `17``:``22` `5``,``530` `chatbot.py``2023``-``09``-``25` `17``:``22` `2``,``153` `configuration_unilm.py``2023``-``09``-``25` `21``:``04` `<``DIR``> data``2023``-``09``-``25` `17``:``22` `4``,``375` `data_helper.py``2023``-``09``-``25` `17``:``22` `9``,``173` `data_set.py``2023``-``09``-``25` `17``:``22` `1``,``304` `dirty_recognize.py``2023``-``09``-``25` `17``:``22` `<``DIR``> images``2023``-``09``-``25` `17``:``22` `<``DIR``> kuakua_robot_model``2023``-``09``-``25` `17``:``22` `13``,``452` `modeling_unilm.py``2023``-``09``-``25` `17``:``22` `<``DIR``> pretrain_model``2023``-``09``-``25` `17``:``22` `4``,``199` `README.md``2023``-``09``-``25` `17``:``22` `88` `requirements.txt``2023``-``09``-``25` `17``:``22` `8``,``337` `train.py``2023``-``09``-``25` `17``:``22` `1``,``861` `trie.py``2023``-``09``-``25` `17``:``54` `<``DIR``> __pycache__`` ``10` `个文件 ``50``,``472` `字节`` ``7` `个目录 ``175``,``152``,``689``,``152` `可用字节` `D:\book\ChatGPTBook``-``main\UniLMProj>python data_helper.py``total number of data: ``121687` `D:\book\ChatGPTBook``-``main\UniLMProj>``dir` `data`` ``驱动器 D 中的卷是 Data`` ``卷的序列号是 CA99``-``555E` ` ``D:\book\ChatGPTBook``-``main\UniLMProj\data 的目录` `2023``-``09``-``25` `21``:``06` `<``DIR``> .``2023``-``09``-``25` `17``:``54` `<``DIR``> ..``2023``-``09``-``25` `17``:``22` `245``,``546` `dirty_words.txt``2023``-``09``-``25` `17``:``56` `21``,``620``,``763` `douban_kuakua_qa.txt``2023``-``09``-``25` `17``:``22` `446` `sample.json``2023``-``09``-``25` `21``:``06` `14``,``272``,``447` `train.json`` ``4` `个文件 ``36``,``139``,``202` `字节`` ``2` `个目录 ``175``,``138``,``414``,``592` `可用字节` `D:\book\ChatGPTBook``-``main\UniLMProj>python train.py``Traceback (most recent call last):`` ``File` `"D:\book\ChatGPTBook-main\UniLMProj\train.py"``, line ``18``, ``in` `<module>`` ``import` `torch``ModuleNotFoundError: No module named ``'torch'` `D:\book\ChatGPTBook``-``main\UniLMProj>pip install torch`` ` `D:\book\ChatGPTBook``-``main\UniLMProj>pip install torch`` ` `D:\book\ChatGPTBook``-``main\UniLMProj>``D:\book\ChatGPTBook`