bar

前言

Hello,大家好,我是GISer Liu😁,一名热爱AI技术的GIS开发者,这个LLM开发基础阶段已经进入尾声了,本文中我们不介绍更多的理论与知识点,而是通过的分析开源项目的解决方案来帮助各位开发者理清自己开发的思路;

在本文中作者将通过分析开源项目 个人知识库助手:学习这个RAG应用的开发流程,思路以及业务代码;帮助读者能学会如何规划自己的LLM的应用开发思路;


一、个人知识库助手

1.项目介绍

①背景

在当今数据量迅速增长的时代,高效管理和检索信息已成为关键技能。为了应对这一挑战以及LLM技术的发展,该项目应运而生,旨在构建一个基于 Langchain 的个人知识库助手。该助手可以通过高效的信息管理系统和强大的检索功能,为用户提供了一个可靠的信息获取平台。

②目标意义
  1. 核心目标:充分利用大型语言模型在处理自然语言查询方面的优势,并进行定制化开发以满足用户需求,从而实现对复杂信息的智能理解和精确回应
  2. 在项目开发过程中,团队深入分析了大型语言模型的潜力和局限,特别是其生成幻觉信息的倾向。
  3. 为了解决幻觉信息的问题,项目集成了 RAG 技术,这是一种结合检索和生成的方法
③主要功能
  • 信息检索:从大规模数据库或知识库中提取相关信息。快速定位和获取精确的内容。
  • 生成式问答:利用检索到的信息生成自然语言回答。提供详细和上下文相关的回答,提高用户体验。
  • 知识更新:不断更新和扩充知识库,保持信息的时效性准确性
  • 。支持多种数据源的集成和管理。
  • 多领域适用:支持在不同领域和主题下的应用,如技术文档、医学知识、法律咨询等。灵活定制以适应特定行业需求。
  • 用户交互:提供自然语言界面,便于用户查询和获取信息。提高用户的搜索效率和满意度。

2.项目部署

①环境要求
  • CPU: Intel 5代处理器(云CPU方面,建议选择 2 核以上的云CPU服务);

阿里云199一年的都可以,也可以自己本机部署;因为是API调用,对电脑性能要求不高;

  • 内存(RAM): 至少 4 GB

  • 操作系统:Windows、macOS、Linux均可

②部署流程

这里作者将整个流程打包:

git clone https://github.com/logan-zou/Chat_with_Datawhale_langchain.git
cd Chat_with_Datawhale_langchain
# 创建 Conda 环境
conda create -n llm-universe python==3.9.0
# 激活 Conda 环境
conda activate llm-universe
# 安装依赖项
pip install -r requirements.txt

运行项目:

# Linux 系统
cd serve
uvicorn api:app --reload 

# Windows 系统
cd serve
python api.py

或者:

python run_gradio.py -model_name='模型名称' -embedding_model='嵌入模型编号' -db_path='知识库文件路径' -persist_path='持久化目录文件路径'

这里记得配置自己的API Key
在这里插入图片描述

③核心思想

本项目其实是针对四种大模型 API 实现了底层封装基于 Langchain 搭建了可切换模型的检索问答链,并实现 API 以及 Gradio 部署的个人轻量大模型应用

④技术栈

技术栈

  • LLM层:统一封装了四个大模型,用作底层模型进行调用。
  • 数据层:通过选择的Embedding模型API进行向量数据库的创建和向量检索。源数据经过Embedding处理后可以被向量数据库使用。
  • 数据库层:基于个人知识库源数据搭建的向量数据库。在本项目中,我们选择了Chroma,当然Faiss也不错。
  • 应用层:确定我们的应用有哪些?RAG、工具使用、RPA等将LLM和具体业务结合使用更佳。在本项目中,我们仅使用RAG,基于LangChain提供的检索问答链基类进行了进一步封装,从而支持不同模型切换以及便捷实现基于数据库的检索问答。
  • 服务层:本项目基于FastAPIGradio。对于后端,我们使用FastAPI即可,无需改变。前端如果只能使用Python开发,使用GradioStreamlit都是快速不错的选择。如果是全栈开发者或企业开发,使用VueReact可以开发出更专业且美观的应用。

当然,这里是因为我们的应用单一,因此只选择了向量数据库。如果还具有其他业务,如列表信息、历史记录、检索信息等,使用结构化数据库也是必要的。向量数据库和结构化数据混合使用更合适。

  • 本项目支持本地M3E Embedding模型和 API Embedding结合的方式进行向量化;

3.应用详解

①业务流程

1、核心架构

llm-universe 个人知识库助手是一个典型的 RAG 项目,通过 langchain+LLM 实现本地知识库问答,建立了全流程可使用开源模型实现的本地知识库对话应用。该项目当前已经支持使用 ChatGPT,星火 Spark 模型,文心大模型,智谱 GLM 等大语言模型的接入。
rag-process

作者这里绘制了一个流程图可以看看:😺😺

整个 RAG 过程包括以下操作:

  1. 用户提出问题 (Query)
  2. 加载和读取知识库文档
  3. 对知识库文档进行分割
  4. 分割后的知识库文本向量化并存入向量库建立索引
  5. 用户提问(Query) 向量化
  6. 在知识库文档向量中匹配出与问句 (Query) 向量最相似的 top k 个
  7. 匹配出的知识库文本作为上下文 (Context) 和问题一起添加到 prompt 中
  8. 提交给 LLM 生成回答 (response)

这里很好理解,之前文章中,LLM基础学习的第三篇详细讲解了其中的每个步骤;

②索引 Index

三步走战略:

  • 文本数据加载和读取
  • 文本数据分割
  • 文本数据向量化

这和我们之前向量数据库搭建的过程差不多;详细内容如下:

(1)知识库数据加载和读取

该项目使用的是:
《机器学习公式详解》PDF版本
《面向开发者的 LLM 入门教程 第一部分 Prompt Engineering》md版本
《强化学习入门指南》MP4版本
以及datawhale总仓库所有开源项目的readme:https://github.com/datawhalechina
data

大家可以根据自己的实际业务选择自己的需要的数据,数据存放在 ../../data_base/knowledge_db 目录下,用户可以将自己的文件存放到这里;

项目官方提供了拉去官网readme的爬虫:

import json
import requests
import os
import base64
import loguru
from dotenv import load_dotenv
# 加载环境变量
load_dotenv()
# 从环境变量中获取TOKEN
TOKEN = os.getenv('TOKEN')
# 定义获取组织仓库的函数
def get_repos(org_name, token, export_dir):
    headers = {
        'Authorization': f'token {token}',
    }
    url = f'https://api.github.com/orgs/{org_name}/repos'
    response = requests.get(url, headers=headers, params={'per_page': 200, 'page': 0})
    if response.status_code == 200:
        repos = response.json()
        loguru.logger.info(f'Fetched {len(repos)} repositories for {org_name}.')
        # 使用 export_dir 确定保存仓库名的文件路径
        repositories_path = os.path.join(export_dir, 'repositories.txt')
        with open(repositories_path, 'w', encoding='utf-8') as file:
            for repo in repos:
                file.write(repo['name'] + '\n')
        return repos
    else:
        loguru.logger.error(f"Error fetching repositories: {response.status_code}")
        loguru.logger.error(response.text)
        return []
# 定义拉取仓库README文件的函数
def fetch_repo_readme(org_name, repo_name, token, export_dir):
    headers = {
        'Authorization': f'token {token}',
    }
    url = f'https://api.github.com/repos/{org_name}/{repo_name}/readme'
    response = requests.get(url, headers=headers)
    if response.status_code == 200:
        readme_content = response.json()['content']
        # 解码base64内容
        readme_content = base64.b64decode(readme_content).decode('utf-8')
        # 使用 export_dir 确定保存 README 的文件路径
        repo_dir = os.path.join(export_dir, repo_name)
        if not os.path.exists(repo_dir):
            os.makedirs(repo_dir)
        readme_path = os.path.join(repo_dir, 'README.md')
        with open(readme_path, 'w', encoding='utf-8') as file:
            file.write(readme_content)
    else:
        loguru.logger.error(f"Error fetching README for {repo_name}: {response.status_code}")
        loguru.logger.error(response.text)
# 主函数
if __name__ == '__main__':
    # 配置组织名称
    org_name = 'datawhalechina'
    # 配置 export_dir
    export_dir = "../../database/readme_db"  # 请替换为实际的目录路径
    # 获取仓库列表
    repos = get_repos(org_name, TOKEN, export_dir)
    # 打印仓库名称
    if repos:
        for repo in repos:
            repo_name = repo['name']
            # 拉取每个仓库的README
            fetch_repo_readme(org_name, repo_name, TOKEN, export_dir)
    # 清理临时文件夹
    # if os.path.exists('temp'):
    #     shutil.rmtree('temp')

以上默认会把这些readme文件放在同目录database下的readme_db文件。其中这些readme文件含有不少无关信息;

😏再运行database/text_summary_readme.py文件可以调用大模型生成每个readme文件的摘要并保存到上述知识库目录/data_base/knowledge_db /readme_summary文件夹中,。代码如下:

import os
from dotenv import load_dotenv
import openai
from test_get_all_repo import get_repos
from bs4 import BeautifulSoup
import markdown
import re
import time
# Load environment variables
load_dotenv()
TOKEN = os.getenv('TOKEN')
# Set up the OpenAI API client
openai_api_key = os.environ["OPENAI_API_KEY"]

# 过滤文本中链接防止大语言模型风控
def remove_urls(text):
    # 正则表达式模式,用于匹配URL
    url_pattern = re.compile(r'https?://[^\s]*')
    # 替换所有匹配的URL为空字符串
    text = re.sub(url_pattern, '', text)
    # 正则表达式模式,用于匹配特定的文本
    specific_text_pattern = re.compile(r'扫描下方二维码关注公众号|提取码|关注|科学上网|回复关键词|侵权|版权|致谢|引用|LICENSE'
                                       r'|组队打卡|任务打卡|组队学习的那些事|学习周期|开源内容|打卡|组队学习|链接')
    # 替换所有匹配的特定文本为空字符串
    text = re.sub(specific_text_pattern, '', text)
    return text

# 抽取md中的文本
def extract_text_from_md(md_content):
    # Convert Markdown to HTML
    html = markdown.markdown(md_content)
    # Use BeautifulSoup to extract text
    soup = BeautifulSoup(html, 'html.parser')

    return remove_urls(soup.get_text())

def generate_llm_summary(repo_name, readme_content,model):
    prompt = f"1:这个仓库名是 {repo_name}. 此仓库的readme全部内容是: {readme_content}\
               2:请用约200以内的中文概括这个仓库readme的内容,返回的概括格式要求:这个仓库名是...,这仓库内容主要是..."
    openai.api_key = openai_api_key
    # 具体调用
    messages = [{"role": "system", "content": "你是一个人工智能助手"},
                {"role": "user", "content": prompt}]
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
    )
    return response.choices[0].message["content"]

def main(org_name,export_dir,summary_dir,model):
    repos = get_repos(org_name, TOKEN, export_dir)

    # Create a directory to save summaries
    os.makedirs(summary_dir, exist_ok=True)

    for id, repo in enumerate(repos):
        repo_name = repo['name']
        readme_path = os.path.join(export_dir, repo_name, 'README.md')
        print(repo_name)
        if os.path.exists(readme_path):
            with open(readme_path, 'r', encoding='utf-8') as file:
                readme_content = file.read()
            # Extract text from the README
            readme_text = extract_text_from_md(readme_content)
            # Generate a summary for the README
            # 访问受限,每min一次
            time.sleep(60)
            print('第' + str(id) + '条' + 'summary开始')
            try:
                summary = generate_llm_summary(repo_name, readme_text,model)
                print(summary)
                # Write summary to a Markdown file in the summary directory
                summary_file_path = os.path.join(summary_dir, f"{repo_name}_summary.md")
                with open(summary_file_path, 'w', encoding='utf-8') as summary_file:
                    summary_file.write(f"# {repo_name} Summary\n\n")
                    summary_file.write(summary)
            except openai.OpenAIError as e:
                summary_file_path = os.path.join(summary_dir, f"{repo_name}_summary风控.md")
                with open(summary_file_path, 'w', encoding='utf-8') as summary_file:
                    summary_file.write(f"# {repo_name} Summary风控\n\n")
                    summary_file.write("README内容风控。\n")
                print(f"Error generating summary for {repo_name}: {e}")
                # print(readme_text)
        else:
            print(f"文件不存在: {readme_path}")
            # If README doesn't exist, create an empty Markdown file
            summary_file_path = os.path.join(summary_dir, f"{repo_name}_summary不存在.md")
            with open(summary_file_path, 'w', encoding='utf-8') as summary_file:
                summary_file.write(f"# {repo_name} Summary不存在\n\n")
                summary_file.write("README文件不存在。\n")
if __name__ == '__main__':
    # 配置组织名称
    org_name = 'datawhalechina'
    # 配置 export_dir
    export_dir = "../database/readme_db"  # 请替换为实际readme的目录路径
    summary_dir="../../data_base/knowledge_db/readme_summary"# 请替换为实际readme的概括的目录路径
    model="gpt-3.5-turbo"  #deepseek-chat,gpt-3.5-turbo,moonshot-v1-8k
    main(org_name,export_dir,summary_dir,model)
  • extract_text_from_md() 函数用来抽取 md 文件中的文本
  • remove_urls() 函数过滤网页链接以及大模型风控词
  • generate_llm_summary() 函数LLM生成每个 readme 的概括

2.在上述知识库构建完毕之后,../../data_base/knowledge_db 目录下就有了目标文件:
file

上面 函数本质上起了一个数据爬虫和数据清理的过程;

作者这里也贡献几个构建个人GIS知识库的爬虫代码:

  • 爬取域名下的所有文字数据
# 爬取域名下的所有文字数据
import requests
import re
import urllib.request
from bs4 import BeautifulSoup
from collections import deque
from html.parser import HTMLParser
from urllib.parse import urlparse
import os
import pandas as pd
import tiktoken
import openai
import numpy as np
from ast import literal_eval
# 正则表达式模式,用于匹配URL
HTTP_URL_PATTERN = r'^http[s]{0,1}://.+$'
# 定义 OpenAI 的 API 密钥
openai.api_key = 'your api key'
# 定义要爬取的根域名
domain = "leafletjs.com"
full_url = "https://leafletjs.com/"
# 创建一个类来解析 HTML 并获取超链接
class HyperlinkParser(HTMLParser):
    def __init__(self):
        super().__init__()
        # 创建一个列表来存储超链接
        self.hyperlinks = []
    # 重写 HTMLParser 的 handle_starttag 方法以获取超链接
    def handle_starttag(self, tag, attrs):
        attrs = dict(attrs)
        # 如果标签是锚点标签且具有 href 属性,则将 href 属性添加到超链接列表中
        if tag == "a" and "href" in attrs:
            self.hyperlinks.append(attrs["href"])
# 函数:从 URL 获取超链接
def get_hyperlinks(url):
    # 尝试打开 URL 并读取 HTML
    try:
        # 打开 URL 并读取 HTML
        with urllib.request.urlopen(url) as response:
            # 如果响应不是 HTML,则返回空列表
            if not response.info().get('Content-Type').startswith("text/html"):
                return []         
            # 解码 HTML
            html = response.read().decode('utf-8')
    except Exception as e:
        print(e)
        return []
    # 创建 HTML 解析器,然后解析 HTML 以获取超链接
    parser = HyperlinkParser()
    parser.feed(html)
    return parser.hyperlinks
# 函数:获取在同一域内的 URL 的超链接
def get_domain_hyperlinks(local_domain, url):
    clean_links = []
    for link in set(get_hyperlinks(url)):
        clean_link = None
        # 如果链接是 URL,请检查是否在同一域内
        if re.search(HTTP_URL_PATTERN, link):
            # 解析 URL 并检查域是否相同
            url_obj = urlparse(link)
            if url_obj.netloc == local_domain:
                clean_link = link

        # 如果链接不是 URL,请检查是否是相对链接
        else:
            if link.startswith("/"):
                link = link[1:]
            elif (
                link.startswith("#")
                or link.startswith("mailto:")
                or link.startswith("tel:")
            ):
                continue
            clean_link = "https://" + local_domain + "/" + link
        if clean_link is not None:
            if clean_link.endswith("/"):
                clean_link = clean_link[:-1]
            clean_links.append(clean_link)
    # 返回在同一域内的超链接列表
    return list(set(clean_links))
# 函数:爬取网页
def crawl(url):
    # 解析 URL 并获取域名
    local_domain = urlparse(url).netloc
    # 创建一个队列来存储要爬取的 URL
    queue = deque([url])
    # 创建一个集合来存储已经看过的 URL(无重复)
    seen = set([url])
    # 创建一个目录来存储文本文件
    if not os.path.exists("text/"):
            os.mkdir("text/")
    if not os.path.exists("text/"+local_domain+"/"):
            os.mkdir("text/" + local_domain + "/")
    # 创建一个目录来存储 CSV 文件
    if not os.path.exists("processed"):
            os.mkdir("processed")
    # 当队列非空时,继续爬取
    while queue:
        # 从队列中获取下一个 URL
        url = queue.pop()
        print(url)  # 用于调试和查看进度
        # 尝试从链接中提取文本,如果失败则继续处理队列中的下一项
        try:
            # 将来自 URL 的文本保存到 <url>.txt 文件中
            with open('text/'+local_domain+'/'+url[8:].replace("/", "_") + ".txt", "w", encoding="UTF-8") as f:
                # 使用 BeautifulSoup 从 URL 获取文本
                soup = BeautifulSoup(requests.get(url).text, "html.parser")
                container = soup.find(class_="container")

                if container is not None:
                    text = container.get_text()
                else:
                    text = ""
                # # 获取文本但去除标签
                # text = soup.get_text()
                # 如果爬虫遇到需要 JavaScript 的页面,它将停止爬取
                if ("You need to enable JavaScript to run this app." in text):
                    print("由于需要 JavaScript,无法解析页面 " + url)  
                # 否则,将文本写入到文本目录中的文件中
                f.write(text)
        except Exception as e:
            print("无法解析页面 " + url)

        # 获取 URL 的超链接并将它们添加到队列中
        for link in get_domain_hyperlinks(local_domain, url):
            if link not in seen:
                queue.append(link)
                seen.add(link)
crawl(full_url)
# 函数:移除字符串中的换行符
def remove_newlines(serie):
    serie = serie.str.replace('\n', ' ')
    serie = serie.str.replace('\\n', ' ')
    serie = serie.str.replace('  ', ' ')
    serie = serie.str.replace('  ', ' ')
    return serie
# 创建一个列表来存储文本文件
texts=[]
# 获取文本目录中的所有文本文件
for file in os.listdir("text/" + domain + "/"):
    # 打开文件并读取文本
    with open("text/" + domain + "/" + file, "r", encoding="UTF-8") as f:
        text = f.read()
        # 忽略前 11 行和后 4 行,然后替换 -、_ 和 #update 为空格
        texts.append((file[11:-4].replace('-',' ').replace('_', ' ').replace('#update',''), text))
  • 爬取域名下所有文字数据的Selenium版本(爬取某些反爬网页有奇效,就是慢一点)
import requests
from bs4 import BeautifulSoup
from msedge.selenium_tools import Edge, EdgeOptions
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.common.by import By
from urllib.parse import urlparse
import os
from collections import deque
import re
# Selenium 配置
options = EdgeOptions()
# 禁用地理位置请求
prefs = {"profile.default_content_setting_values.geolocation": 2}
options.add_experimental_option("prefs", prefs)
options.add_argument('--ignore-certificate-errors')
options.use_chromium = True
# options.add_argument("--headless")
options.add_argument("--disable-gpu")
options.add_argument("--disable-extensions")
# 正则表达式模式,用于匹配URL
HTTP_URL_PATTERN = r'^http[s]{0,1}://.+$'
# 定义要爬取的根域名
domain = "python.langchain.com"
base_url = "https://python.langchain.com/docs/"
# 函数:从 URL 获取超链接
def get_hyperlinks(url, browser):
    try:
        browser.get(url)
        # WebDriverWait(browser, 10).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
        WebDriverWait(browser, 10).until(EC.presence_of_element_located((By.CLASS_NAME, "menu")))
        html = browser.page_source
    except Exception as e:
        print(e)
        return []
    soup = BeautifulSoup(html, 'html.parser')
    links = [a.get('href') for a in soup.find_all('a', href=True)]
    return links
# 函数:获取在同一域内的 URL 的超链接
def get_domain_hyperlinks(local_domain, base_url, url, browser):
    clean_links = []
    for link in set(get_hyperlinks(url, browser)):
        clean_link = None
        # 检查链接是否为完整的 URL
        if re.search(HTTP_URL_PATTERN, link):
            url_obj = urlparse(link)
            if url_obj.netloc == local_domain:
                # 检查链接是否是基础 URL 或其锚点的变化
                if link.startswith(base_url) or (url_obj.path == urlparse(base_url).path and url_obj.fragment):
                    clean_link = link
        else:
            # 处理相对链接
            if link.startswith("/"):
                complete_link = "https://" + local_domain + link
                if complete_link.startswith(base_url):
                    clean_link = complete_link
            elif link.startswith("#"):
                # 处理锚点链接
                clean_link = base_url + link
            elif link.startswith("mailto:") or link.startswith("tel:"):
                continue
        if clean_link is not None:
            clean_links.append(clean_link)
    return list(set(clean_links))
# 函数:保存文本到文件
def save_text_to_file(url, text, local_domain):
    # 检查文本长度,如果太短或为空则不保存
    if len(text.strip()) < 50:
        print(f"文本内容太少,不保存: {url}")
        return
    base_dir = "text"
    domain_dir = os.path.join(base_dir, local_domain)
    if not os.path.exists(domain_dir):
        os.makedirs(domain_dir)
    filename = f"{domain_dir}/{url.replace('https://', '').replace('/', '_')}.txt"
    with open(filename, 'w', encoding='utf-8') as file:
        file.write(text)
    print(f"已保存: {filename}")
# 函数:爬取网页
def crawl(base_url):
    local_domain = urlparse(base_url).netloc
    queue = deque([base_url])
    seen = set([base_url])
    # 启动 Edge 浏览器
    browser = Edge(options=options)
    try:
        while queue:
            url = queue.pop()
            try:
                browser.get(url)
                WebDriverWait(browser, 60).until(EC.presence_of_element_located((By.CLASS_NAME, "theme-doc-markdown")))
                html = browser.page_source
                # 使用 BeautifulSoup 解析和处理 HTML
                soup = BeautifulSoup(html, 'html.parser')
                # 专门提取 <main> 标签的内容
                main_content = soup.find('div',class_="theme-doc-markdown markdown")
                if main_content is not None:
                    text = main_content.get_text()
                else:
                    continue
                # 保存文本到文件
                save_text_to_file(url, text, local_domain)
            except Exception as e:
                print("无法解析页面:", url, "; 错误:", e)
                continue
            for link in get_domain_hyperlinks(local_domain, base_url, url, browser):
                if link not in seen:
                    queue.append(link)
                    seen.add(link)
    finally:
        browser.quit()
crawl(base_url)

其中有 mp4 格式,md 格式,以及 pdf 格式,对这些文件的加载方式,该项目将代码放在了 project/database/create_db.py文件 下,部分代码如下。其中 pdf 格式文件用 PyMuPDFLoader 加载器,md格式文件用UnstructuredMarkdownLoader加载器:

from langchain.document_loaders import UnstructuredFileLoader
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyMuPDFLoader
from langchain.vectorstores import Chroma
```python
# 首先实现基本配置
```python
DEFAULT_DB_PATH = "../../data_base/knowledge_db"
DEFAULT_PERSIST_PATH = "../../data_base/vector_db"
... 
...
...
def file_loader(file, loaders):
    if isinstance(file, tempfile._TemporaryFileWrapper):
        file = file.name
    if not os.path.isfile(file):
        [file_loader(os.path.join(file, f), loaders) for f in  os.listdir(file)]
        return
    file_type = file.split('.')[-1]
    if file_type == 'pdf':
        loaders.append(PyMuPDFLoader(file))
    elif file_type == 'md':
        pattern = r"不存在|风控"
        match = re.search(pattern, file)
        if not match:
            loaders.append(UnstructuredMarkdownLoader(file))
    elif file_type == 'txt':
        loaders.append(UnstructuredFileLoader(file))
    return
(2)文本分割和向量化

文本分割和向量化将上述载入的知识库文本或进行 token 长度进行分割,该项目利用 Langchain 中的文本分割器根据 chunk_size (块大小)和 chunk_overlap (块与块之间的重叠大小)进行分割。

  • chunk_size 指每个块包含的字符或 Token(如单词、句子等)的数量
  • chunk_overlap 指两个块之间共享的字符数量,用于保持上下文的连贯性,避免分割丢失上下文信息

tip:可以设置一个最大的 Token 长度,然后根据这个最大的 Token 长度来切分文档。这样切分出来的文档片段是一个一个均匀长度的文档片段。而片段与片段之间的一些重叠的内容,能保证检索的时候能够检索到相关的文档片段。

这部分文本分割代码也在 project/database/create_db.py 文件,该项目采用了 langchain 中 RecursiveCharacterTextSplitter 文本分割器进行分割。代码如下:

def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="openai"):
    """
    该函数用于加载 PDF 文件,切分文档,生成文档的嵌入向量,创建向量数据库。

    参数:
    file: 存放文件的路径。
    embeddings: 用于生产 Embedding 的模型

    返回:
    vectordb: 创建的数据库。
    """
    if files == None:
        return "can't load empty file"
    if type(files) != list:
        files = [files]
    loaders = []
    [file_loader(file, loaders) for file in files]
    docs = []
    for loader in loaders:
        if loader is not None:
            docs.extend(loader.load())
    # 切分文档
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=150)
    split_docs = text_splitter.split_documents(docs)

而在切分好知识库文本之后,需要对文本进行 向量化,文本向量化代码文件路径是project/embedding/call_embedding.py ,文本嵌入方式可选本地 m3e 模型,以及调用 openaizhipuaiapi 的方式进行文本嵌入。代码如下:

import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
sys.path.append(r"../../")
from embedding.zhipuai_embedding import ZhipuAIEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from llm.call_llm import parse_llm_api_key


def get_embedding(embedding: str, embedding_key: str = None, env_file: str = None):
   if embedding == 'm3e':
      return HuggingFaceEmbeddings(model_name="moka-ai/m3e-base")
   if embedding_key == None:
      embedding_key = parse_llm_api_key(embedding)
   if embedding == "openai":
      return OpenAIEmbeddings(openai_api_key=embedding_key)
   elif embedding == "zhipuai":
      return ZhipuAIEmbeddings(zhipuai_api_key=embedding_key)
   else:
      raise ValueError(f"embedding {embedding} not support ")

读者也可自行配置Emebdding模型;

(3)向量数据库

在对知识库文本进行分割和向量化后,就需要定义一个向量数据库用来存放文档片段和对应的向量表示了,在向量数据库中,数据被表示为向量形式,每个向量代表一个数据项。这些向量可以是数字、文本、图像或其他类型的数据。

向量数据库使用高效的索引和查询算法来加速向量数据的存储和检索过程
该项目选择 chromadb 向量数据库(类似的向量数据库还有 faiss 😏等)。定义向量库对应的代码也在 /database/create_db.py 文件中,persist_directory 即为本地持久化地址,vectordb.persist() 操作可以持久化向量数据库到本地,后续可以再次载入本地已有的向量库。完整的文本分割,获取向量化,并且定义向量数据库代码如下:

def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="openai"):
    """
    该函数用于加载 PDF 文件,切分文档,生成文档的嵌入向量,创建向量数据库。
    参数:
    file: 存放文件的路径。
    embeddings: 用于生产 Embedding 的模型

    返回:
    vectordb: 创建的数据库。
    """
    if files == None:
        return "can't load empty file"
    if type(files) != list:
        files = [files]
    loaders = []
    [file_loader(file, loaders) for file in files]
    docs = []
    for loader in loaders:
        if loader is not None:
            docs.extend(loader.load())
    # 切分文档
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=150)
    split_docs = text_splitter.split_documents(docs)
    if type(embeddings) == str:
        embeddings = get_embedding(embedding=embeddings)
    # 定义持久化路径
    persist_directory = '../../data_base/vector_db/chroma'
    # 加载数据库
    vectordb = Chroma.from_documents(
    documents=split_docs,
    embedding=embeddings,
    persist_directory=persist_directory  # 允许我们将persist_directory目录保存到磁盘上
    ) 

    vectordb.persist()
    return vectordb
③向量检索和生成

进入了 RAG 的检索和生成阶段,即对问句 Query 向量化后在知识库文档向量中匹配出与问句 Query 向量最相似的 top k 个片段,**检索出知识库文本文本作为上下文 Context 和问题⼀起添加到 prompt 中,然后提交给 LLM 生成回答 **。

(1)向量数据库检索

接下去利用向量数据库来进行高效的检索。

向量数据库是一种用于有效搜索大规模高维向量空间中相似度的库,能够在大规模数据集中快速找到与给定 query 向量最相似的向量。

代码如下所示:

question="什么是机器学习"
sim_docs = vectordb.similarity_search(question,k=3)
print(f"检索到的内容数:{len(sim_docs)}")
for i, sim_doc in enumerate(sim_docs):
    print(f"检索到的第{i}个内容: \n{sim_doc.page_content[:200]}", end="\n--------------\n")

运行结果:

检索到的内容数:3
检索到的第0个内容: 
导,同时也能体会到这三门数学课在机器学习上碰撞产生的“数学之美”。
1.1
引言
本节以概念理解为主,在此对“算法”和“模型”作补充说明。“算法”是指从数据中学得“模型”的具
体方法,例如后续章节中将会讲述的线性回归、对数几率回归、决策树等。“算法”产出的结果称为“模型”,
通常是具体的函数或者可抽象地看作为函数,例如一元线性回归算法产出的模型即为形如 f(x) = wx + b

的一元一次函数。
--------------

检索到的第1个内容: 
模型:机器学习的一般流程如下:首先收集若干样本(假设此时有 100 个),然后将其分为训练样本
(80 个)和测试样本(20 个),其中 80 个训练样本构成的集合称为“训练集”,20 个测试样本构成的集合
称为“测试集”,接着选用某个机器学习算法,让其在训练集上进行“学习”(或称为“训练”),然后产出

得到“模型”(或称为“学习器”),最后用测试集来测试模型的效果。执行以上流程时,表示我们已经默
--------------

检索到的第2个内容: 
→_→
欢迎去各大电商平台选购纸质版南瓜书《机器学习公式详解》
←_←
第 1 章
绪论
本章作为“西瓜书”的开篇,主要讲解什么是机器学习以及机器学习的相关数学符号,为后续内容作
铺垫,并未涉及复杂的算法理论,因此阅读本章时只需耐心梳理清楚所有概念和数学符号即可。此外,在
阅读本章前建议先阅读西瓜书目录前页的《主要符号表》,它能解答在阅读“西瓜书”过程中产生的大部
分对数学符号的疑惑。
本章也作为
(2)大模型llm的调用

以该项目 project/qa_chain/model_to_llm.py 代码为例,**在 project/llm/ 的目录文件夹下分别定义了 星火spark,智谱glm,文心llm等开源模型api调用的封装,**并在 project/qa_chain/model_to_llm.py 文件中导入了这些模块,可以根据用户传入的模型名字进行调用 llm。代码如下:

def model_to_llm(model:str=None, temperature:float=0.0, appid:str=None, api_key:str=None,Spark_api_secret:str=None,Wenxin_secret_key:str=None):
        """
        星火:model,temperature,appid,api_key,api_secret
        百度问心:model,temperature,api_key,api_secret
        智谱:model,temperature,api_key
        OpenAI:model,temperature,api_key
        """
        if model in ["gpt-3.5-turbo", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-32k"]:
            if api_key == None:
                api_key = parse_llm_api_key("openai")
            llm = ChatOpenAI(model_name = model, temperature = temperature , openai_api_key = api_key)
        elif model in ["ERNIE-Bot", "ERNIE-Bot-4", "ERNIE-Bot-turbo"]:
            if api_key == None or Wenxin_secret_key == None:
                api_key, Wenxin_secret_key = parse_llm_api_key("wenxin")
            llm = Wenxin_LLM(model=model, temperature = temperature, api_key=api_key, secret_key=Wenxin_secret_key)
        elif model in ["Spark-1.5", "Spark-2.0"]:
            if api_key == None or appid == None and Spark_api_secret == None:
                api_key, appid, Spark_api_secret = parse_llm_api_key("spark")
            llm = Spark_LLM(model=model, temperature = temperature, appid=appid, api_secret=Spark_api_secret, api_key=api_key)
        elif model in ["chatglm_pro", "chatglm_std", "chatglm_lite"]:
            if api_key == None:
                api_key = parse_llm_api_key("zhipuai")
            llm = ZhipuAILLM(model=model, zhipuai_api_key=api_key, temperature = temperature)
        else:
            raise ValueError(f"model{model} not support!!!")
        return llm
(3)prompt和构建问答链

接下去来到了最后一步,设计完基于知识库问答的 prompt,就可以结合上述检索和大模型调用进行答案的生成。构建 prompt 的格式如下,具体可以根据自己业务需要进行修改:

from langchain.prompts import PromptTemplate
# template = """基于以下已知信息,简洁和专业的来回答用户的问题。
#             如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分。
#             答案请使用中文。
#             总是在回答的最后说“谢谢你的提问!”。
# 已知信息:{context}
# 问题: {question}"""
template = """使用以下上下文来回答最后的问题。如果你不知道答案,就说你不知道,不要试图编造答
案。最多使用三句话。尽量使答案简明扼要。总是在回答的最后说“谢谢你的提问!”。
{context}
问题: {question}
有用的回答:"""

QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context","question"],
                                 template=template)

# 运行 chain
并且构建问答链:创建检索 QA 链的方法 RetrievalQA.from_chain_type() 有如下参数:

参数介绍

  • llm:指定使用的 LLM
  • chain type :RetrievalQA.from_chain_type(chain_type=“map_reduce”),
  • 自定义 prompt :通过在RetrievalQA.from_chain_type()方法中,指定chain_type_kwargs参数,而该参数:chain_type_kwargs = {“prompt”: PROMPT}
  • 返回源文档:通过RetrievalQA.from_chain_type()方法中指定:return_source_documents=True参数;也可以使用RetrievalQAWithSourceChain()方法,返回源文档的引用(坐标或者叫主键、索引)
#自定义 QA 链
self.qa_chain = RetrievalQA.from_chain_type(llm=self.llm,
                                        retriever=self.retriever,
                                        return_source_documents=True,
                                        chain_type_kwargs={"prompt":self.QA_CHAIN_PROMPT})

问答链效果如下:基于召回结果和 query 结合起来构建的 prompt 效果

question_1 = "什么是南瓜书?"
question_2 = "王阳明是谁?"
result = qa_chain({"query": question_1})
print("大模型+知识库后回答 question_1 的结果:")
print(result["result"])
大模型+知识库后回答 question_1 的结果:
南瓜书是对《机器学习》(西瓜书)中难以理解的公式进行解析和补充推导细节的一本书。谢谢你的提问!
result = qa_chain({"query": question_2})
print("大模型+知识库后回答 question_2 的结果:")
print(result["result"])
大模型+知识库后回答 question_2 的结果:
我不知道王阳明是谁,谢谢你的提问!

以上检索问答链代码都在project/qa_chain/QA_chain_self.py 中,此外该项目还实现了带记忆的检索问答链,两种自定义检索问答链内部实现细节类似,只是调用了不同的 LangChain 链。完整带记忆的检索问答链条代码 project/qa_chain/Chat_QA_chain_self.py 如下:

from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI

from qa_chain.model_to_llm import model_to_llm
from qa_chain.get_vectordb import get_vectordb


class Chat_QA_chain_self:
    """"
    带历史记录的问答链  
    - model:调用的模型名称
    - temperature:温度系数,控制生成的随机性
    - top_k:返回检索的前k个相似文档
    - chat_history:历史记录,输入一个列表,默认是一个空列表
    - history_len:控制保留的最近 history_len 次对话
    - file_path:建库文件所在路径
    - persist_path:向量数据库持久化路径
    - appid:星火
    - api_key:星火、百度文心、OpenAI、智谱都需要传递的参数
    - Spark_api_secret:星火秘钥
    - Wenxin_secret_key:文心秘钥
    - embeddings:使用的embedding模型
    - embedding_key:使用的embedding模型的秘钥(智谱或者OpenAI)  
    """
    def __init__(self,model:str, temperature:float=0.0, top_k:int=4, chat_history:list=[], file_path:str=None, persist_path:str=None, appid:str=None, api_key:str=None, Spark_api_secret:str=None,Wenxin_secret_key:str=None, embedding = "openai",embedding_key:str=None):
        self.model = model
        self.temperature = temperature
        self.top_k = top_k
        self.chat_history = chat_history
        #self.history_len = history_len
        self.file_path = file_path
        self.persist_path = persist_path
        self.appid = appid
        self.api_key = api_key
        self.Spark_api_secret = Spark_api_secret
        self.Wenxin_secret_key = Wenxin_secret_key
        self.embedding = embedding
        self.embedding_key = embedding_key


        self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding,self.embedding_key)
        
    
    def clear_history(self):
        "清空历史记录"
        return self.chat_history.clear()

    
    def change_history_length(self,history_len:int=1):
        """
        保存指定对话轮次的历史记录
        输入参数:
        - history_len :控制保留的最近 history_len 次对话
        - chat_history:当前的历史对话记录
        输出:返回最近 history_len 次对话
        """
        n = len(self.chat_history)
        return self.chat_history[n-history_len:]

 
    def answer(self, question:str=None,temperature = None, top_k = 4):
        """"
        核心方法,调用问答链
        arguments: 
        - question:用户提问
        """
        
        if len(question) == 0:
            return "", self.chat_history
        
        if len(question) == 0:
            return ""
        
        if temperature == None:
            temperature = self.temperature

        llm = model_to_llm(self.model, temperature, self.appid, self.api_key, self.Spark_api_secret,self.Wenxin_secret_key)

        #self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

        retriever = self.vectordb.as_retriever(search_type="similarity",   
                                        search_kwargs={'k': top_k})  #默认similarity,k=4

        qa = ConversationalRetrievalChain.from_llm(
            llm = llm,
            retriever = retriever
        )

        #print(self.llm)
        result = qa({"question": question,"chat_history": self.chat_history})       #result里有question、chat_history、answer
        answer =  result['answer']
        self.chat_history.append((question,answer)) #更新历史记录

        return self.chat_history  #返回本次回答和更新后的历史记录

OK,时间有限,分析完毕,各位读者有兴趣可以学习一下,给官方一个Star😀😀😀;

文章参考

项目地址


thank_watch

如果觉得我的文章对您有帮助,三连+关注便是对我创作的最大鼓励!或者一个star🌟也可以😂.

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐