import os

from langchain.document_loaders import (
    CSVLoader,
    DirectoryLoader,
    JSONLoader,
    PDFMinerLoader,
    TextLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
)
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma, FAISS

from args import args


def embed_data(embeddings_model_name, original_data_path, preprocessed_data_path):
    """Load the persisted FAISS vector store if it exists, otherwise build and save it."""
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)  # embedding model
    if does_vectorstore_exist(preprocessed_data_path):
        print('++++++++++++++++++++++++++++ Loading local knowledge base ++++++++++++++++++++++++++++')
        # The store already exists: load it from disk.
        # db = Chroma(persist_directory=preprocessed_data_path, embedding_function=embeddings)
        db = FAISS.load_local(preprocessed_data_path, embeddings)
    else:
        print('++++++++++++++++++++++++++++ Building local knowledge base ++++++++++++++++++++++++++++')
        # No store yet: split the source documents, embed them, and persist the index.
        texts = load_data(original_data_path)  # split the documents into chunks
        # db = Chroma.from_documents(texts, embeddings, persist_directory=preprocessed_data_path)
        db = FAISS.from_documents(texts, embeddings)
        db.save_local(preprocessed_data_path)
    return db


def load_data(original_data_path):
    """Load every supported file type from the data directory and split it into chunks."""
    # Map file extensions to their LangChain document loaders.
    data_map = {
        ".csv": CSVLoader,
        ".json": JSONLoader,
        ".doc": UnstructuredWordDocumentLoader,
        ".docx": UnstructuredWordDocumentLoader,
        ".txt": TextLoader,
        ".md": UnstructuredMarkdownLoader,
        ".epub": UnstructuredEPubLoader,
        ".pdf": PDFMinerLoader,
        ".ppt": UnstructuredPowerPointLoader,
        ".pptx": UnstructuredPowerPointLoader,
        ".html": UnstructuredHTMLLoader,
    }
    data_all = []
    for ext, loader_cls in data_map.items():
        # Per-loader options: CSV files are read as GBK, plain-text files have
        # their encoding auto-detected; the other loaders use their defaults.
        if ext == ".csv":
            loader_kwargs = {'encoding': 'gbk'}
        elif ext == ".txt":
            loader_kwargs = {'autodetect_encoding': True}
        else:
            loader_kwargs = None
        loader = DirectoryLoader(original_data_path,
                                 glob=f"*{ext}",
                                 show_progress=True,       # show a progress bar
                                 use_multithreading=True,  # load files in parallel
                                 loader_cls=loader_cls,
                                 silent_errors=True,       # skip files that fail to load
                                 loader_kwargs=loader_kwargs)
        data_all.append(loader.load())
    # chunk_size is the maximum chunk length, chunk_overlap the maximum overlap
    # between consecutive chunks, and length_function measures chunk length.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=args.chunk_size,
                                                   chunk_overlap=args.chunk_overlap,
                                                   length_function=len)
    texts = []
    for docs in data_all:
        if not docs:
            continue
        texts.extend(text_splitter.split_documents(docs))
    return texts


# Rebuild the vector store from the source documents and persist it again.
def dBgx(original_data_path, preprocessed_data_path, embeddings_model_name):
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)  # embedding model
    texts = load_data(original_data_path)  # split the documents into chunks
    # db = Chroma.from_documents(texts, embeddings, persist_directory=preprocessed_data_path)
    db = FAISS.from_documents(texts, embeddings)
    db.save_local(preprocessed_data_path)
    return db


def does_vectorstore_exist(persist_directory):
    """Return True if a persisted FAISS index (index.pkl and index.faiss) is present."""
    return (os.path.exists(os.path.join(persist_directory, 'index.pkl'))
            and os.path.exists(os.path.join(persist_directory, 'index.faiss')))
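

# Illustrative usage sketch, not part of the original module: the model name and
# paths below are placeholders, and the query call assumes the standard LangChain
# FAISS `similarity_search` API.
if __name__ == "__main__":
    db = embed_data(
        embeddings_model_name="sentence-transformers/all-MiniLM-L6-v2",  # hypothetical model
        original_data_path="./source_documents",                         # hypothetical input directory
        preprocessed_data_path="./vectorstore",                          # hypothetical index directory
    )
    for doc in db.similarity_search("example query", k=3):
        print(doc.metadata.get("source"), doc.page_content[:80])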