You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

460 lines
19 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# coding=utf-8
from 数据处理 import embed_data,dBgx
import random as rand
from langchain import PromptTemplate, LLMChain
from 模型 import CustomLLM
from args import args
import torch
from transformers import AutoTokenizer, AutoModel,AutoConfig
from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from paddleocr import PaddleOCR
import os
import fitz
# Welcome text for the knowledge-base recommendation mode; it also reminds the
# user to pick a suitable chunk length/overlap when building the knowledge base.
# (Where these strings are displayed is not visible in this file section.)
zhishiku_init1='''欢迎来到基于知识库的菜单推荐,模型的回答效果除了与模型能力有关外,还需注意创建知识库时根据数据选择合适的切分长度和重叠长度,喜好提交时尽量详细数据库里存在的内容'''
# Welcome text for the menu-based recommendation mode (asks the user to upload a menu first).
caidan_init0="""欢迎来到基于菜单的菜单推荐,询问前请先上传菜单"""
# Welcome text for the free recommendation mode driven only by the large model.
damodel_init0="""欢迎来到大模型随机推荐,输入你的喜好,模型将根据你的喜好随机推荐"""
# Welcome text for image recognition; tells the user to fill in the API
# credentials in the configuration file and links to the application guide.
zhinengshibie = '''欢迎使用智能识别使用前请先去配置文件填写相关的api。
api申请方法见https://blog.csdn.net/2303_79001442/article/details/132093208
别问为什么不用本地模型,问就是显存不足。显存足够的可以考虑更换自己的模型。'''
def get_answer(query, chatbot):
    """Recommend dishes from the knowledge base and record the exchange.

    Retrieves the ``args.topk_1`` documents most relevant to ``query`` from
    the vector store built by ``embed_data`` and passes them to the LLM
    inside a fixed recommendation prompt.

    Args:
        query: The user's text; used for retrieval and stored in the history.
        chatbot: Chat history list; a ``[query, answer]`` pair is appended.

    Returns:
        A tuple ``(chatbot, '')``.
    """
    language_model = CustomLLM()
    vector_store = embed_data(
        args.embeddings_model_name,
        args.original_data_path,
        args.preprocessed_data_path,
    )
    related_docs = vector_store.as_retriever(
        search_kwargs={"k": args.topk_1}
    ).get_relevant_documents(query)
    recommend_prompt = PromptTemplate(
        template="""根据已知内容推荐几盘菜肴。已知内容: {adjective}""",
        input_variables=["adjective"],
    )
    chain = LLMChain(prompt=recommend_prompt, llm=language_model)
    # The retrieved documents are handed to the prompt as-is.
    reply = chain.run(adjective=related_docs)
    chatbot.append([query, reply])
    return chatbot, ''
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
def get_answer_caidan(query, chatbot):
    """Pick dishes from the uploaded menu that match the user's preferences.

    Loads and splits the menu at ``args.path_caidan``, indexes it in a FAISS
    store, retrieves the ``args.topk_1`` most relevant chunks for ``query``
    and asks the LLM to choose dishes from them.

    Args:
        query: The user's dietary preferences.
        chatbot: Chat history list; a ``[query, answer]`` pair is appended.

    Returns:
        A tuple ``(chatbot, '')``.
    """
    embedder = HuggingFaceEmbeddings(model_name=args.embeddings_model_name)
    language_model = CustomLLM()
    menu_chunks = load_caidan(args.path_caidan, sentence_size=args.chunk_size)  # load the menu
    menu_index = FAISS.from_documents(menu_chunks, embedder)
    related_docs = menu_index.as_retriever(
        search_kwargs={"k": args.topk_1}
    ).get_relevant_documents(query)
    template_text = (
        "基于我提供的菜单信息和饮食喜好信息。简洁的回答我的问题。\n"
        "菜单信息:{adjective}\n"
        "饮食喜好信息:{xihao}\n"
        "问题:{question}"
    )
    fixed_question = '''请根据我的喜好从菜单中选出我可能喜欢的菜肴,答案中不允许出现菜单中没有的信息。'''
    menu_prompt = PromptTemplate(
        template=template_text,
        input_variables=["adjective", "xihao", "question"],
    )
    chain = LLMChain(prompt=menu_prompt, llm=language_model)
    reply = chain.run(adjective=related_docs, xihao=query, question=fixed_question)
    chatbot.append([query, reply])
    return chatbot, ''
def get_answer_model(query, chatbot):
    """Ask the LLM alone (no retrieval) to recommend dishes for a preference.

    Args:
        query: The user's dietary preferences.
        chatbot: Chat history list; a ``[query, answer]`` pair is appended.

    Returns:
        A tuple ``(chatbot, '')``.
    """
    language_model = CustomLLM()
    chef_prompt = PromptTemplate(
        template="""你现在是一位优秀的厨师,你要根据我的饮食喜好推荐几种美食。我的饮食喜好如下:{xihao}。回答只能展示美食的名字不需要过多的解释。""",
        input_variables=["xihao"],
    )
    chain = LLMChain(prompt=chef_prompt, llm=language_model)
    reply = chain.run(xihao=query)
    chatbot.append([query, reply])
    return chatbot, ''
def get_answer_guihua(query, chatbot, mode2):
    """Suggest dishes that can be cooked from the ingredients the user has.

    Args:
        query: The ingredients the user already has.
        chatbot: Chat history list; a ``[query, answer]`` pair is appended
            when ``mode2`` is one of the two recognised values.
        mode2: ``'启用'`` to ground the answer on documents retrieved from the
            ``*_g`` knowledge base, ``'不启用'`` to ask the LLM directly. Any
            other value leaves ``chatbot`` untouched.

    Returns:
        A tuple ``(chatbot, '')``.
    """
    language_model = CustomLLM()
    if mode2 == '启用':
        vector_store = embed_data(
            args.embeddings_model_name,
            args.original_data_path_g,
            args.preprocessed_data_path_g,
        )
        related_docs = vector_store.as_retriever(
            search_kwargs={"k": args.topk_1}
        ).get_relevant_documents(query)
        template_text = (
            "请参考以下的素材信息和已有食材为我推荐几种菜肴,并指出缺少的食材和具体的操作过程。\n"
            "素材信息:{adjective}\n"
            "已有食材:{shicai}"
        )
        chain = LLMChain(
            prompt=PromptTemplate(template=template_text, input_variables=["adjective", 'shicai']),
            llm=language_model,
        )
        reply = chain.run(adjective=related_docs, shicai=query)
    elif mode2 == '不启用':
        chain = LLMChain(
            prompt=PromptTemplate(
                template="""你是一位优秀的厨师,我将提供一些食材,你要根据食材制作几种不同的美味菜肴。提供的食材如下:{食材}。""",
                input_variables=["食材"],
            ),
            llm=language_model,
        )
        reply = chain.run(食材=query)
    else:
        # Unrecognised mode: nothing is generated or appended.
        return chatbot, ''
    chatbot.append([query, reply])
    return chatbot, ''
def get_answer_guihua2(query, chatbot, mode2):
    """Explain how to cook the dish named in ``query``.

    Args:
        query: Name of the dish the user wants to make.
        chatbot: Chat history list; a ``[query, answer]`` pair is appended
            when ``mode2`` is one of the two recognised values.
        mode2: ``'启用'`` to ground the answer on documents retrieved from the
            ``*_g`` knowledge base, ``'不启用'`` to ask the LLM directly. Any
            other value leaves ``chatbot`` untouched.

    Returns:
        A tuple ``(chatbot, '')``.
    """
    language_model = CustomLLM()
    if mode2 == '启用':
        vector_store = embed_data(
            args.embeddings_model_name,
            args.original_data_path_g,
            args.preprocessed_data_path_g,
        )
        related_docs = vector_store.as_retriever(
            search_kwargs={"k": args.topk_1}
        ).get_relevant_documents(query)
        template_text = (
            "请参考以下的素材信息回答{shicai}制作过程。\n"
            "素材信息:{adjective}"
        )
        chain = LLMChain(
            prompt=PromptTemplate(template=template_text, input_variables=["adjective", 'shicai']),
            llm=language_model,
        )
        reply = chain.run(adjective=related_docs, shicai=query)
    elif mode2 == '不启用':
        chain = LLMChain(
            prompt=PromptTemplate(
                template="""你是一位优秀的厨师,我想制作一份{shicai},请问具体应该怎么做。""",
                input_variables=["shicai"],
            ),
            llm=language_model,
        )
        reply = chain.run(shicai=query)
    else:
        # Unrecognised mode: nothing is generated or appended.
        return chatbot, ''
    chatbot.append([query, reply])
    return chatbot, ''
def ziyouduihau(query, chatbot):
    """Free-form chat: send ``query`` straight to the pretrained chat model.

    The model and tokenizer named by ``args.pretrained_model_name`` are
    loaded on every call, with ``temperature`` and ``top_k`` taken from
    ``args``. Each call starts from an empty history.

    Args:
        query: The user's message, used verbatim as the prompt.
        chatbot: Chat history list; a ``[query, response]`` pair is appended.

    Returns:
        A tuple ``(chatbot, '')``.
    """
    generation_overrides = {
        "temperature": args.temperature,
        "top_k": args.topk,
    }
    config = AutoConfig.from_pretrained(args.pretrained_model_name, trust_remote_code=True)
    config.update(generation_overrides)
    chat_model = AutoModel.from_pretrained(
        args.pretrained_model_name,
        config=config,
        trust_remote_code=True,
    ).half().to(args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name, trust_remote_code=True)
    chat_model = chat_model.eval()
    # The returned history is discarded; only the reply text is kept.
    response, _ = chat_model.chat(tokenizer, query, history=[])
    chatbot.append([query, response])
    return chatbot, ''
def change_data_path(data_path):
    """Store ``data_path`` as ``args.original_data_path``."""
    setattr(args, "original_data_path", data_path)
def change_data_path_g(data_path):
    """Store ``data_path`` as ``args.original_data_path_g``.

    This is the raw-data path read by ``db_update_click`` and the
    ``get_answer_guihua*`` functions. The previous implementation wrote to
    ``args.original_data_path`` instead, silently overwriting the other
    knowledge base's path and leaving the ``_g`` path unchanged.
    """
    args.original_data_path_g = data_path
def change_P_data_path(data_path):
    """Store ``data_path`` as ``args.preprocessed_data_path``."""
    setattr(args, "preprocessed_data_path", data_path)
def change_P_data_path_g(data_path):
    """Store ``data_path`` as ``args.preprocessed_data_path_g``."""
    setattr(args, "preprocessed_data_path_g", data_path)
def change_emb_name(emb_name):
    """Store ``emb_name`` as ``args.embeddings_model_name``."""
    setattr(args, "embeddings_model_name", emb_name)
def change_search_top_k(search_top_k):
    """Set ``args.topk_1`` (retrieval top-k) from ``search_top_k``.

    An empty string is ignored so the current value is kept; anything else
    is converted with ``int``.
    """
    if search_top_k == '':
        return
    args.topk_1 = int(search_top_k)
def change_model_top_k(model_top_k):
    """Set ``args.topk`` (model sampling top-k) from ``model_top_k``.

    An empty string is ignored so the current value is kept; anything else
    is converted with ``int``.
    """
    if model_top_k == '':
        return
    args.topk = int(model_top_k)
def change_model_temperature(model_temperature):
    """Set ``args.temperature`` from ``model_temperature``.

    Empty input (``''`` or ``None``) is ignored so the current value is
    kept, matching ``change_search_top_k`` and ``change_model_top_k``.
    Previously an empty string raised ``ValueError`` from ``float('')``.

    Raises:
        ValueError: If a non-empty value cannot be converted to ``float``.
    """
    if model_temperature is None or model_temperature == '':
        return
    args.temperature = float(model_temperature)
def change_model_name(model_name):
    """Store ``model_name`` as ``args.pretrained_model_name``."""
    setattr(args, "pretrained_model_name", model_name)
def db_update_click1(chatbot):
    """Rebuild the main knowledge base and report completion in the chat.

    Calls ``dBgx`` with ``args.original_data_path``,
    ``args.preprocessed_data_path`` and ``args.embeddings_model_name``.

    Args:
        chatbot: Chat history list; a ``[None, message]`` entry is appended.

    Returns:
        The same ``chatbot`` list.
    """
    dBgx(
        args.original_data_path,
        args.preprocessed_data_path,
        args.embeddings_model_name,
    )
    chatbot.append([None, '知识库已更新'])
    return chatbot
def db_update_click(chatbot):
    """Rebuild the ``_g`` knowledge base and report completion in the chat.

    Calls ``dBgx`` with ``args.original_data_path_g``,
    ``args.preprocessed_data_path_g`` and ``args.embeddings_model_name``.

    Args:
        chatbot: Chat history list; a ``[None, message]`` entry is appended.

    Returns:
        The same ``chatbot`` list.
    """
    dBgx(
        args.original_data_path_g,
        args.preprocessed_data_path_g,
        args.embeddings_model_name,
    )
    chatbot.append([None, '知识库已更新'])
    return chatbot
def change_chunk_size(chunk_size):
    """Store ``chunk_size`` unchanged as ``args.chunk_size``."""
    setattr(args, "chunk_size", chunk_size)
def change_chunk_overlap(chunk_overlap):
    """Store ``chunk_overlap`` unchanged as ``args.chunk_overlap``."""
    setattr(args, "chunk_overlap", chunk_overlap)
def clean_memory(chatbot):
    """Release cached accelerator memory held by torch and log it in the chat.

    On CUDA the cache is emptied and IPC handles are collected. Otherwise,
    when the MPS backend is available, ``torch.mps.empty_cache`` is tried;
    any failure there is printed together with an upgrade hint instead of
    being raised.

    Args:
        chatbot: Chat history list; a ``[None, message]`` entry is appended.

    Returns:
        The same ``chatbot`` list.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif torch.backends.mps.is_available():
        try:
            from torch.mps import empty_cache
            empty_cache()
        except Exception as err:
            # Best effort only: report the failure and carry on.
            print(err)
            print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
    chatbot.append([None, '显存已清理'])
    return chatbot
def change_caidan_path(caidan):
    """Store ``caidan`` as ``args.path_caidan`` (the menu file path)."""
    setattr(args, "path_caidan", caidan)
def change_path_zidingyi(suijidz):
    """Store ``suijidz`` as ``args.path_zidingyi``."""
    setattr(args, "path_zidingyi", suijidz)
def change_suiji_num(suiji_num):
    """Store ``suiji_num`` unchanged as ``args.num_select``."""
    setattr(args, "num_select", suiji_num)
def change_shitu_top_k(shitu_top_k):
    """Store ``shitu_top_k`` unchanged as ``args.top_num``."""
    setattr(args, "top_num", shitu_top_k)
def change_shitu_filter_threshold(shitu_filter_threshold):
    """Store ``shitu_filter_threshold`` unchanged as ``args.filter_threshold``."""
    setattr(args, "filter_threshold", shitu_filter_threshold)
def change_shitu_baike_num(shitu_baike_num):
    """Store ``shitu_baike_num`` unchanged as ``args.baike_num``."""
    setattr(args, "baike_num", shitu_baike_num)
def Random_selection1(chatbot):
    """Pick random lines from the file at ``args.path_zidingyi``.

    Every line of the file (stripped of surrounding whitespace) is a
    candidate. ``args.num_select`` lines are drawn without replacement.

    The requested count is coerced to ``int`` and clamped to the range
    ``[0, number of lines]``: ``random.sample`` raises ``ValueError`` when
    asked for more items than exist and ``TypeError`` for a non-integer
    count, and the setter for ``args.num_select`` stores its input unchanged.

    Args:
        chatbot: Chat history list; a ``[None, message]`` entry is appended.

    Returns:
        The same ``chatbot`` list.
    """
    with open(args.path_zidingyi, 'r', encoding='utf-8') as file:
        content = [line.strip() for line in file]
    num_select = max(0, min(int(args.num_select), len(content)))
    elements = rand.sample(content, num_select)
    shuchu = '随机内容为:\n' + '\n '.join(elements)
    chatbot.append([None, shuchu])
    return chatbot
def Random_selection2(chatbot, content_list):
    """Pick random lines from a newline-separated string.

    ``content_list`` is split on ``"\\n"`` and ``args.num_select`` entries are
    drawn without replacement.

    The requested count is coerced to ``int`` and clamped to the range
    ``[0, number of entries]``: ``random.sample`` raises ``ValueError`` when
    asked for more items than exist and ``TypeError`` for a non-integer
    count, and the setter for ``args.num_select`` stores its input unchanged.

    Args:
        chatbot: Chat history list; a ``[None, message]`` entry is appended.
        content_list: Candidate entries as one string, one entry per line.

    Returns:
        The same ``chatbot`` list.
    """
    candidates = content_list.split("\n")
    num_select = max(0, min(int(args.num_select), len(candidates)))
    elements = rand.sample(candidates, num_select)
    shuchu = '随机内容为:\n' + '\n '.join(elements)
    chatbot.append([None, shuchu])
    return chatbot
class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
    """Loader that OCRs an image file (such as a PNG or JPG) with PaddleOCR
    and partitions the recognised text with unstructured."""

    def _get_elements(self) -> List:
        """Run OCR on ``self.file_path`` and return the partitioned text elements."""

        def image_ocr_txt(filepath, dir_path="tmp_files"):
            """Write the OCR text of ``filepath`` to ``<dir>/<dir_path>/<name>.txt``.

            Returns the path of the text file that was written.
            """
            full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
            # exist_ok avoids the check-then-create race of exists() + makedirs().
            os.makedirs(full_dir_path, exist_ok=True)
            filename = os.path.split(filepath)[-1]
            ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=False, show_log=False)
            result = ocr.ocr(img=filepath)
            # NOTE(review): some PaddleOCR versions appear to return None for an
            # image with no detected text; skip such entries instead of failing
            # on iteration - confirm against the installed version.
            ocr_result = [i[1][0] for line in (result or []) if line for i in line]
            txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename))
            with open(txt_file_path, 'w', encoding='utf-8') as fout:
                fout.write("\n".join(ocr_result))
            return txt_file_path

        txt_file_path = image_ocr_txt(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
class ChineseTextSplitter(CharacterTextSplitter):
    """Text splitter that breaks Chinese text into sentences.

    Attributes set in ``__init__``:
        pdf: When true, whitespace is normalised before splitting.
        sentence_size: Length above which ``split_text`` tries to break a
            sentence further at weaker separators.
    """

    def __init__(self, pdf: bool = False, sentence_size: int = 100, **kwargs):
        """Forward ``kwargs`` to ``CharacterTextSplitter`` and store the options."""
        super().__init__(**kwargs)
        self.pdf = pdf
        self.sentence_size = sentence_size

    def split_text1(self, text: str) -> List[str]:
        """Split ``text`` at sentence-ending punctuation.

        A separator (optionally followed by up to two closing quotes) is
        glued onto the sentence that precedes it.
        """
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub(r'\s', ' ', text)
            text = text.replace("\n\n", "")
        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list

    def split_text(self, text: str) -> List[str]:  # TODO: this logic needs further optimisation
        """Split ``text`` into sentences, re-splitting over-long ones.

        Newlines are inserted after sentence terminators, then every piece
        longer than ``self.sentence_size`` is split again at commas/periods,
        then at runs of whitespace, then at single spaces.
        """
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r'\s', " ", text)
            text = re.sub("\n\n", "", text)
        text = re.sub(r'([;.!?。!?\?])([^”’])', r"\1\n\2", text)  # single-character sentence terminators
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis
        # Chinese ellipsis: two U+2026 characters. The pattern previously read
        # (\{2}), which matches the literal text "{2}" because the ellipsis
        # character itself was missing.
        text = re.sub(r'(\u2026{2})([^"’”」』])', r"\1\n\2", text)
        text = re.sub(r'([;!?。!?\?]["’”」』]{0,2})([^;!?,。!?\?])', r'\1\n\2', text)
        # If a terminator precedes a closing quote, the quote is the real end of
        # the sentence, so the \n goes after the quote; the rules above are
        # careful to keep the quotes.
        text = text.rstrip()  # drop any trailing \n at the end of the paragraph
        # Many rule sets also treat the semicolon, dashes and ASCII double quotes
        # specially; they are ignored here and can be added if needed.
        ls = [i for i in text.split("\n") if i]
        # Each for-loop walks the list object it started with, while ls /
        # ele1_ls / ele2_ls are rebound to new, expanded lists inside the body.
        for ele in ls:
            if len(ele) > self.sentence_size:
                ele1 = re.sub(r'([,.]["’”」』]{0,2})([^,.])', r'\1\n\2', ele)
                ele1_ls = ele1.split("\n")
                for ele_ele1 in ele1_ls:
                    if len(ele_ele1) > self.sentence_size:
                        ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
                        ele2_ls = ele_ele2.split("\n")
                        for ele_ele2 in ele2_ls:
                            if len(ele_ele2) > self.sentence_size:
                                ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
                                ele2_id = ele2_ls.index(ele_ele2)
                                ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[
                                                                                                       ele2_id + 1:]
                        ele_id = ele1_ls.index(ele_ele1)
                        ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
                # Renamed from ``id`` to avoid shadowing the builtin.
                ls_idx = ls.index(ele)
                ls = ls[:ls_idx] + [i for i in ele1_ls if i] + ls[ls_idx + 1:]
        return ls
class UnstructuredPaddlePDFLoader(UnstructuredFileLoader):
    """Loader that extracts a PDF's text plus the OCR text of its embedded
    images, then partitions the combined text with unstructured."""

    def _get_elements(self) -> List:
        """Convert ``self.file_path`` to text and return the partitioned elements."""

        def pdf_ocr_txt(filepath, dir_path="tmp_files"):
            """Write page text and image OCR text to ``<dir>/<dir_path>/<name>.txt``.

            Returns the path of the text file that was written.
            """
            full_dir_path = os.path.join(os.path.dirname(filepath), dir_path)
            # exist_ok avoids the check-then-create race of exists() + makedirs().
            os.makedirs(full_dir_path, exist_ok=True)
            ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=False, show_log=False)
            doc = fitz.open(filepath)
            txt_file_path = os.path.join(full_dir_path, f"{os.path.split(filepath)[-1]}.txt")
            # Scratch file reused for every embedded image handed to the OCR engine.
            img_name = os.path.join(full_dir_path, 'tmp.png')
            try:
                with open(txt_file_path, 'w', encoding='utf-8') as fout:
                    for i in range(doc.page_count):
                        page = doc[i]
                        fout.write(page.get_text(""))
                        fout.write("\n")
                        for img in page.get_images():
                            pix = fitz.Pixmap(doc, img[0])
                            if pix.n - pix.alpha >= 4:
                                # Four or more colour components: convert to RGB before saving.
                                pix = fitz.Pixmap(fitz.csRGB, pix)
                            pix.save(img_name)
                            result = ocr.ocr(img_name)
                            # NOTE(review): some PaddleOCR versions appear to return
                            # None for an image with no detected text - confirm.
                            ocr_result = [i[1][0] for line in (result or []) if line for i in line]
                            if ocr_result:
                                fout.write("\n".join(ocr_result))
                                # Terminate the OCR text so it does not fuse with the
                                # next image's or next page's first line.
                                fout.write("\n")
            finally:
                # Release the document and scratch image even if extraction fails.
                doc.close()
                if os.path.exists(img_name):
                    os.remove(img_name)
            return txt_file_path

        txt_file_path = pdf_ocr_txt(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(filename=txt_file_path, **self.unstructured_kwargs)
def load_caidan(filepath, sentence_size=args.chunk_size):
    """Load a menu file and split it into documents.

    The loader is chosen from the (case-insensitive) file extension:
    ``.pdf`` uses ``UnstructuredPaddlePDFLoader``, ``.jpg``/``.png`` use
    ``UnstructuredPaddleImageLoader``, anything else uses
    ``UnstructuredFileLoader``. The text is split with a
    ``ChineseTextSplitter``.

    Args:
        filepath: Path of the menu file.
        sentence_size: Passed to ``ChineseTextSplitter``. Note the default is
            the value ``args.chunk_size`` had when this function was defined.

    Returns:
        The documents produced by the loader's ``load_and_split``.
    """
    lowered = filepath.lower()
    is_pdf = lowered.endswith(".pdf")
    if is_pdf:
        # The PDF loader is built without mode="elements", unlike the others.
        loader = UnstructuredPaddlePDFLoader(filepath)
    elif lowered.endswith((".jpg", ".png")):
        loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
    else:
        loader = UnstructuredFileLoader(filepath, mode="elements")
    splitter = ChineseTextSplitter(pdf=is_pdf, sentence_size=sentence_size)
    return loader.load_and_split(text_splitter=splitter)
def tupian_shibie(mode1, filepath, chatbot):
    """Recognise a dish or an ingredient in an image and post the result to the chat.

    The image bytes are sent through the ``aip`` ``AipImageClassify`` client
    with ``top_num``, ``filter_threshold`` and ``baike_num`` taken from
    ``args``.

    Args:
        mode1: ``'菜品识别'`` selects dish detection with the ``*_c``
            credentials; any other value selects ingredient recognition with
            the ``*_g`` credentials.
        filepath: Path of the image file to recognise.
        chatbot: Chat history list; a ``[None, message]`` entry is appended.

    Returns:
        The same ``chatbot`` list.

    Raises:
        KeyError: If the service response has no ``"result"`` key or an entry
            lacks ``name`` / ``probability`` / ``score``.
    """
    from aip import AipImageClassify

    # Optional request parameters.
    options = {}
    options["top_num"] = args.top_num
    options["filter_threshold"] = str(args.filter_threshold)
    options["baike_num"] = args.baike_num
    image = get_file_content(filepath)
    inf = '根据预测有以下几种可能:'
    if mode1 == '菜品识别':
        # The client gets its own name rather than rebinding the imported class.
        client = AipImageClassify(args.APP_ID_c, args.API_KEY_c, args.SECRET_KEY_c)
        result = client.dishDetect(image, options)["result"]
        sections = []
        for index, item in enumerate(result, start=1):
            name = item['name']
            probability = item['probability']
            # Resolve each encyclopedia field on its own. The previous code
            # appended the URL before a missing description raised, and the
            # bare except then appended a second URL, shifting the lists.
            baike_info = item.get('baike_info')
            if not isinstance(baike_info, dict):
                baike_info = {}
            baike_url = baike_info.get('baike_url', '没有百科链接')
            description = baike_info.get('description', '没有百科信息')
            sections.append(
                f"\n{index}: {name}\n"
                f"可能性为:{probability}\n"
                f"百科链接:{baike_url}\n"
                f"相关描述:{description}\n"
                "\n"
            )
        inf += "".join(sections)
    else:
        client = AipImageClassify(args.APP_ID_g, args.API_KEY_g, args.SECRET_KEY_g)
        result = client.ingredient(image, options)["result"]
        sections = []
        for index, item in enumerate(result, start=1):
            name = item['name']
            score = item['score']
            sections.append(
                f"\n{index}: {name}\n"
                f"可能性为:{score}\n"
                "\n"
            )
        inf += "".join(sections)
    chatbot.append([None, inf])
    return chatbot
def get_file_content(filePath):
    """Return the raw bytes of the file at ``filePath``."""
    with open(filePath, mode='rb') as stream:
        payload = stream.read()
    return payload