Qanything 2 0源码解析系列2:上传文件 10648c4b656180e0b747f4c810aa2831

Qanything 2.0源码解析系列2:上传文件

type: Post
status: Published
date: 2024/09/19
summary: Qanything上传文件的逻辑
category: 技术分享

1.x版本中上传文档和解析文档是写在一起的,2.0开始上传文档可以通过接口上传,解析文档是一个单独的服务。

如何启动qanything服务,参考: Qanything 2.0项目部署启动手把手教程

新建知识库源码解析,参考: Qanything 2.0源码解析系列1:新建知识库

sanic_py中定义的上传文件的接口信息:

app.add_route(upload_files, "/api/local_doc_qa/upload_files", methods=['POST']) # tags=["上传文件"]

upload_files 方法定义在handler.py中。

📝 逐行代码解析

upload_files方法如下,代码比较多,在1.x中核心代码就一行local_doc_qa.insert_files_to_milvus(user_id, kb_id, local_files),2.0中还保留着这一行代码,只是被注释起来了。

@get_time_async
async def upload_files(req: request):
		# local_doc_qa这个是在app.run启动之前初始化好的一个全局变量,是LocalDocQA类的类对象,是Qanything的核心类。
    local_doc_qa: LocalDocQA = req.app.ctx.local_doc_qa
    
    # 拿到请求参数传递的user_id的值
    user_id = safe_get(req, 'user_id')
    
    # 拿到请求参数传递的user_info的值,不传默认是1234
    user_info = safe_get(req, 'user_info', "1234")
    
    '''
    检查user_id和user_info的有效性
    1. user_id、user_info不能是None
    2. user_id的长度不能超过64,必须是字符串类型,只能是数字、字母、下划线的组合,且必须以字母开头
    3. user_info必须是纯数字
    '''
    passed, msg = check_user_id_and_user_info(user_id, user_info)
    if not passed:
        return sanic_json({"code": 2001, "msg": msg})
        
    # 将user_id和user_info拼成一个新user_id
    user_id = user_id + '__' + user_info
    debug_logger.info("upload_files %s", user_id)
    debug_logger.info("user_info %s", user_info)
    
    # 拿到请求参数传递的kb_id的值
    kb_id = safe_get(req, 'kb_id')
    
    '''
    kb_id纠正, 就是在这个kb_id后面拼接一个KB_SUFFIX,如果有这个KB_SUFFIX,直接返回kb_id,如果不存在,那么:
    1. 如果kb_id以"_FAQ"结尾,比如KBc86eaa3f278f4ef9908780e8e558c6eb_FAQ,那么在kb_id和FAQ之间增加一个KB_SUFFIX,默认是_240625, 最终变成了如KBc86eaa3f278f4ef9908780e8e558c6eb_240625_FAQ
    2. 如果kb_id不以"_FAQ"结尾,直接在kb_id后面增加一个KB_SUFFIX
    3. 返回新的kb_id
    '''
    kb_id = correct_kb_id(kb_id)
    debug_logger.info("kb_id %s", kb_id)
    
    # 拿到请求参数传递的mode的值
    mode = safe_get(req, 'mode', default='soft')  # soft代表不上传同名文件,strong表示强制上传同名文件
    debug_logger.info("mode: %s", mode)
    
    # 拿到请求参数传递的chunk_size的值,默认是800
    chunk_size = safe_get(req, 'chunk_size', default=DEFAULT_PARENT_CHUNK_SIZE)
    debug_logger.info("chunk_size: %s", chunk_size)
    
    # 拿到请求参数传递的use_local_file的值,默认是false。
    use_local_file = safe_get(req, 'use_local_file', 'false')
    
    # 如果use_local_file == 'true',默认会遍历“项目根目录/data”下的所有文件,如果是false,则处理请求参数中传递过来的文件。所以一般都是false
    if use_local_file == 'true':
        files = read_files_with_extensions()
    else:
        files = req.files.getlist('files')
    debug_logger.info(f"{user_id} upload files number: {len(files)}")
    
    # 从mysql qanything数据库 KnowledgeBase表中 查看这个user_id对应的kb_id是否存在。不要被milvus_summary迷惑了,其实是mysql_client对象
    not_exist_kb_ids = local_doc_qa.milvus_summary.check_kb_exist(user_id, [kb_id])
    if not_exist_kb_ids:
        msg = "invalid kb_id: {}, please check...".format(not_exist_kb_ids)
        return sanic_json({"code": 2001, "msg": msg, "data": [{}]})
		
		# 从File表中查一下这个kb_id下面已经存在多少个文件了,返回这些文件。
    exist_files = local_doc_qa.milvus_summary.get_files(user_id, kb_id)
    if len(exist_files) + len(files) > 10000:
        return sanic_json({"code": 2002,
                           "msg": f"fail, exist files is {len(exist_files)}, upload files is {len(files)}, total files is {len(exist_files) + len(files)}, max length is 10000."})
		
		'''
		定义三个局部变量 
		1. data没啥大用,用于返回请求结果,主要是可以看到file_id
		2. 每个文件都会生成一个LocalFile对象,存放在local_files列表中
		3. file_names用于存放处理好的文件名
		'''
    data = []
    local_files = []
    file_names = []
    
    # 遍历files,其实就是遍历上传的文件,这个遍历的目的是处理文件名,有些文件名乱七八糟的。只要你的文件名规规矩矩的,不要有一些特殊符号,这里不会修改你的文件名的,执行完还是原来的文件名。
    for file in files:
        if isinstance(file, str):
            file_name = os.path.basename(file)
        else:
            debug_logger.info('ori name: %s', file.name)
            file_name = urllib.parse.unquote(file.name, encoding='UTF-8')
            debug_logger.info('decode name: %s', file_name)
        # # 使用正则表达式替换以%开头的字符串
        # file_name = re.sub(r'%\w+', '', file_name)
        # 删除掉全角字符
        file_name = re.sub(r'[\uFF01-\uFF5E\u3000-\u303F]', '', file_name)
        debug_logger.info('cleaned name: %s', file_name)
        # max_length = 255 - len(construct_qanything_local_file_nos_key_prefix(file_id)) == 188
        file_name = truncate_filename(file_name, max_length=110)
        file_names.append(file_name)

    exist_file_names = []
    if mode == 'soft':
		    # 从File表中查询kb_id和file_names同时存在的记录
        exist_files = local_doc_qa.milvus_summary.check_file_exist_by_name(user_id, kb_id, file_names)
        
        # f[1]对应的是文件名
        exist_file_names = [f[1] for f in exist_files]
        
        # 这个循环就是打日志了
        for exist_file in exist_files:
            file_id, file_name, file_size, status = exist_file
            debug_logger.info(f"{file_name}, {status}, existed files, skip upload")
            # await post_data(user_id, -1, file_id, status, msg='existed files, skip upload')

    now = datetime.now()
    timestamp = now.strftime("%Y%m%d%H%M")

    failed_files = []
    for file, file_name in zip(files, file_names):
        if file_name in exist_file_names:
            continue
        
        # 初始化一个LocalFile对象,LocalFile也是Qanything的核心类。具体请看下文的介绍。
        local_file = LocalFile(user_id, kb_id, file, file_name)
        
        # 快速估算文件的字符数,返回字符数。图片类型无法统计字符数,直接返回True
        chars = fast_estimate_file_char_count(local_file.file_location)
        
        # MAX_CHARS值是一百万,一般不会超过
        if chars > MAX_CHARS:
            debug_logger.warning(f"fail, file {file_name} chars is {chars}, max length is {MAX_CHARS}.")
            # return sanic_json({"code": 2003, "msg": f"fail, file {file_name} chars is too much, max length is {MAX_CHARS}."})
            failed_files.append(file_name)
            continue
        
        # 变量取值
        file_id = local_file.file_id
        file_size = len(local_file.file_content)
        file_location = local_file.file_location
        
        # 将LocalFile对象放到一个列表里,1.x版本中后续会遍历这个local_file,然后解析文件,2.0版本中这个变量没啥用。
        local_files.append(local_file)
        
        # 在File表中插入一条记录
        msg = local_doc_qa.milvus_summary.add_file(file_id, user_id, kb_id, file_name, file_size, file_location,
                                                   chunk_size, timestamp)
        debug_logger.info(f"{file_name}, {file_id}, {msg}")
        data.append(
            {"file_id": file_id, "file_name": file_name, "status": "gray", "bytes": len(local_file.file_content),
             "timestamp": timestamp, "estimated_chars": chars})
		
		# 核心方法,将文件内容embedding后,存入milvus向量数据库中,在1.x中会执行下面注释掉的这行代码,但是2.0之后解析文件的逻辑是一个单独的服务。
    # asyncio.create_task(local_doc_qa.insert_files_to_milvus(user_id, kb_id, local_files))
    if exist_file_names:
        msg = f'warning,当前的mode是soft,无法上传同名文件{exist_file_names},如果想强制上传同名文件,请设置mode:strong'
    elif failed_files:
        msg = f"warning, {failed_files} chars is too much, max characters length is {MAX_CHARS}, skip upload."
    else:
        msg = "success,后台正在飞速上传文件,请耐心等待"
    return sanic_json({"code": 200, "msg": msg, "data": data})

LocalFile类

1.x版本的LocalFile写的很多方法,2.0了就一个__init__

class LocalFile:
		# 接收四个参数 user_id kb_id file file_name
    def __init__(self, user_id, kb_id, file: Union[File, str, Dict], file_name):
        self.user_id = user_id
        self.kb_id = kb_id
        
        # 随机初始化一个file_id
        self.file_id = uuid.uuid4().hex
        self.file_name = file_name
        self.file_url = ''
        if isinstance(file, Dict):
            self.file_location = "FAQ"
            self.file_content = b''
        elif isinstance(file, str):
            self.file_location = "URL"
            self.file_content = b''
            self.file_url = file
        else:
		        # 拿到文件的二进制内容
            self.file_content = file.body
            
            # upload_path = '根目录/QANY_DB/content/{user_id}'
            upload_path = os.path.join(UPLOAD_ROOT_PATH, user_id)
            
            # file_dir = 根目录/QANY_DB/content/{user_id}/{kb_id}/{file_id}
            file_dir = os.path.join(upload_path, self.kb_id, self.file_id)
            
            # 文件夹不存在就创建一个
            os.makedirs(file_dir, exist_ok=True)
            
            # 文件保存在本地目录中
            self.file_location = os.path.join(file_dir, self.file_name)
            #  如果文件不存在:
            if not os.path.exists(self.file_location):
                with open(self.file_location, 'wb') as f:
                    f.write(self.file_content)

🤗 总结归纳

1.x上传文件之后并对文件进行处理存milvus了,2.0之后文件解析是一个单独的服务,这里仅仅是将相关记录存入了数据库。后续通过文件解析服务再从数据库取值根据不同的文件类型进行解析。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值