ADD file via upload

main
pycsq8k9h 8 months ago
parent 5a32472cd5
commit 37b9a27ca2

966
web.py

@ -0,0 +1,966 @@
from __future__ import annotations
import os
from PIL import Image
import torch.nn as nn
from torchvision import models, transforms
from pathlib import Path
from threading import Thread
from typing import Union
import gradio as gr
import torch
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer, AutoModel
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.document_loaders import TextLoader
from langchain.chains import create_retrieval_chain
from langchain.tools.retriever import create_retriever_tool
import asyncio
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import logging
import pdfplumber
from functools import partial
from importlib.metadata import version
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
ToolMessageChunk,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from packaging.version import parse
logger = logging.getLogger(__name__)
huggingface_path = "/home/wangty/lsx/temp-lsx/word_language_model/models--sentence-transformers--all-mpnet-base-v2"
"""ZhipuAI与langchain部分不兼容重写langchain部分代码"""
def is_zhipu_v2() -> bool:
"""Return whether zhipu API is v2 or more."""
_version = parse(version("zhipuai"))
return _version.major >= 2
def _create_retry_decorator(
llm: ChatZhipuAI,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
import zhipuai
errors = [
zhipuai.ZhipuAIError,
zhipuai.APIStatusError,
zhipuai.APIRequestFailedError,
zhipuai.APIReachLimitError,
zhipuai.APIInternalError,
zhipuai.APIServerFlowExceedError,
zhipuai.APIResponseError,
zhipuai.APIResponseValidationError,
zhipuai.APITimeoutError,
]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)
def convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.
Args:
message: The LangChain message.
Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
# If tool calls only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
Args:
_dict: The dictionary.
Returns:
The LangChain message.
"""
role = _dict.get("role")
if role == "user":
return HumanMessage(content=_dict.get("content", ""))
elif role == "assistant":
content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
if tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""))
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
return ToolMessage(
content=_dict.get("content", ""),
tool_call_id=_dict.get("tool_call_id"),
additional_kwargs=additional_kwargs,
)
else:
return ChatMessage(content=_dict.get("content", ""), role=role)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
additional_kwargs: Dict = {}
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)
class ChatZhipuAI(BaseChatModel):
"""
`ZHIPU AI` large language chat models API.
To use, you should have the ``zhipuai`` python package installed.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatZhipuAI
zhipuai_chat = ChatZhipuAI(
temperature=0.5,
api_key="your-api-key",
model_name="glm-3-turbo",
)
"""
zhipuai: Any
zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided."""
client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field("glm-3-turbo", alias="model")
"""
Model name to use.
-glm-3-turbo:
According to the input of natural language instructions to complete a
variety of language tasks, it is recommended to use SSE or asynchronous
call request interface.
-glm-4:
According to the input of natural language instructions to complete a
variety of language tasks, it is recommended to use SSE or asynchronous
call request interface.
"""
temperature: float = Field(0.95)
"""
What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot
be equal to 0.
The larger the value, the more random and creative the output; The smaller
the value, the more stable or certain the output will be.
You are advised to adjust top_p or temperature parameters based on application
scenarios, but do not adjust the two parameters at the same time.
"""
top_p: float = Field(0.7)
"""
Another method of sampling temperature is called nuclear sampling. The value
ranges from 0.0 to 1.0 and cannot be equal to 0 or 1.
The model considers the results with top_p probability quality tokens.
For example, 0.1 means that the model decoder only considers tokens from the
top 10% probability of the candidate set.
You are advised to adjust top_p or temperature parameters based on application
scenarios, but do not adjust the two parameters at the same time.
"""
request_id: Optional[str] = Field(None)
"""
Parameter transmission by the client must ensure uniqueness; A unique
identifier used to distinguish each request, which is generated by default
by the platform when the client does not transmit it.
"""
do_sample: Optional[bool] = Field(True)
"""
When do_sample is true, the sampling policy is enabled. When do_sample is false,
the sampling policy temperature and top_p are disabled
"""
streaming: bool = Field(False)
"""Whether to stream the results or not."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
max_tokens: Optional[int] = None
"""Number of chat completions to generate for each prompt."""
max_retries: int = 2
"""Maximum number of retries to make when generating."""
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return the type of chat model."""
return "zhipuai"
@property
def lc_secrets(self) -> Dict[str, str]:
return {"zhipuai_api_key": "ZHIPUAI_API_KEY"}
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "zhipuai"]
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
if self.model_name:
attributes["model"] = self.model_name
if self.streaming:
attributes["streaming"] = self.streaming
if self.max_tokens:
attributes["max_tokens"] = self.max_tokens
return attributes
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling ZhipuAI API."""
params = {
"model": self.model_name,
"stream": self.streaming,
"temperature": self.temperature,
"top_p": self.top_p,
"do_sample": self.do_sample,
**self.model_kwargs,
}
if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
return params
@property
def _client_params(self) -> Dict[str, Any]:
"""Get the parameters used for the zhipuai client."""
zhipuai_creds: Dict[str, Any] = {
"request_id": self.request_id,
}
return {**self._default_params, **zhipuai_creds}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
try:
from zhipuai import ZhipuAI
if not is_zhipu_v2():
raise RuntimeError(
"zhipuai package version is too low"
"Please install it via 'pip install --upgrade zhipuai'"
)
self.client = ZhipuAI(
api_key=self.zhipuai_api_key, # 填写您的 APIKey
)
except ImportError:
raise RuntimeError(
"Could not import zhipuai package. "
"Please install it via 'pip install zhipuai'"
)
def completions(self, **kwargs) -> Any | None:
return self.client.chat.completions.create(**kwargs)
async def async_completions(self, **kwargs) -> Any:
loop = asyncio.get_running_loop()
partial_func = partial(self.client.chat.completions.create, **kwargs)
response = await loop.run_in_executor(
None,
partial_func,
)
return response
async def async_completions_result(self, task_id):
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None,
self.client.asyncCompletions.retrieve_completion_result,
task_id,
)
return response
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
generations = []
if not isinstance(response, dict):
response = response.dict()
for res in response["choices"]:
message = convert_dict_to_message(res["message"])
generation_info = dict(finish_reason=res.get("finish_reason"))
if "index" in res:
generation_info["index"] = res["index"]
gen = ChatGeneration(
message=message,
generation_info=generation_info,
)
generations.append(gen)
token_usage = response.get("usage", {})
llm_output = {
"token_usage": token_usage,
"model_name": self.model_name,
"task_id": response.get("id", ""),
"created_time": response.get("created", ""),
}
return ChatResult(generations=generations, llm_output=llm_output)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._client_params
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [convert_message_to_dict(m) for m in messages]
return message_dicts, params
def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return self.completions(**kwargs)
return _completion_with_retry(**kwargs)
async def acompletion_with_retry(
self,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
return await self.async_completions(**kwargs)
return await _completion_with_retry(**kwargs)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
"""Generate a chat response."""
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {
**params,
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = False,
**kwargs: Any,
) -> ChatResult:
"""Asynchronously generate a chat response."""
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {
**params,
**({"stream": stream} if stream is not None else {}),
**kwargs,
}
response = await self.acompletion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the chat response in chunks."""
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
"""问答系统界面"""
# 自定义CSS
custom_css = """
#gr-chatbot {
background-color: #FFFFFF !important; /* 设置Chatbot组件背景为浅灰色 */
color: #000000 !important; /* 设置Chatbot字体颜色为黑色 */
border: 2px solid #ccc; /* 设置Chatbot边框 */
border-radius: 20px; /* 设置圆角边框 */
padding: 10px;
height: 1600px;
display: block !important; /* 强制为块级元素 */
}
#example-btn {
width: 250px !important; /* 按钮宽度变窄 */
font-size: 20px !important; /* 按钮字体大小 */
font-family: "仿宋", serif !important; /* 设置字体为宋体 */
padding: 5px !important; /* 调整按钮内边距 */
margin-bottom: 5px; /* 按钮之间的间距 */
background-color: #E4E4E7 !important;
border: none !important; /* 去掉按钮的边框 */
}
#example-container {
border: 2px solid #ccc; /* 设置方框边框 */
padding:10px; /* 增加内边距 */
border-radius: 30px; /* 设置圆角 */
margin-bottom: 10px; /* 与其他元素保持间距 */
width: 285px !important; /* 固定按钮容器宽度 */
height: 460px !important; /* 固定按钮容器高度 */
overflow-y: auto !important; /* 自动显示垂直滚动条 */
overflow-x: hidden !重要; /* 禁用水平滚动条 */
display: block !important; /* 强制为块级元素 */
background-color: #E4E4E7 !important;
}
#send-btn {
width: 60px !important; /* 缩小按钮宽度 */
height: 60px !important; /* 缩小按钮的高度 */
font-size: 20px; /* 调整按钮字体大小 */
font-family: "仿宋", serif !important; /* 设置字体为宋体 */
padding: 5px; /* 设置按钮的内边距 */
vertical-align: middle; /* 按钮内容垂直居中 */
background-color: #FFFFFF !important; /* 按钮背景为白色 */
border: 1px solid #CCCCCC !important; /* 将按钮的边框设置为黑色 */
}
#textbox {
height: 60px !重要; /* 设置输入框的高度 */
line-height: normal !重要; /* 确保行高正常 */
font-size: 20px !重要; /* 调整输入框的字体大小 */
}
"""
examples = [
"请问感冒的症状有哪些?",
"如何预防高血压?",
"治疗糖尿病的常见方法是什么?",
"什么是流感?",
"什么是偏头痛的常见诱因?",
"支气管炎怎么治疗?",
"冠心病的早期症状是什么?",
"治疗哮喘有哪些常见药物?",
"癫痫病如何治疗?",
"高血压不能吃什么?"
]
# HTML内容标题的定义
html_content = """
<h1 style="text-align: center;">医疗聊天机器人</h1>
"""
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
MODEL_PATH = os.environ.get('MODEL_PATH', '/home/wangty/lsx/temp-lsx/word_language_model/medical_demo')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
def _resolve_path(path: Union[str, Path]) -> Path:
return Path(path).expanduser().resolve()
def load_model_and_tokenizer(
model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
model_dir = _resolve_path(model_dir)
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model_dir
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
)
return model, tokenizer
model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)
"""重写StopTokens部分停止id不一致"""
class StopOnTokens(StoppingCriteria):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
stop_ids = model.config.eos_token_id
for stop_id in stop_ids:
if input_ids[0][-1] == stop_id:
return True
return False
ZHIPUAI_API_KEY = "4a1f42142b2adfd2bdcbb864eb8ddf8f.6o0LAkhsANjQHGlJ"
llm = ChatZhipuAI(
temperature=0.1,
api_key=ZHIPUAI_API_KEY,
model_name="glm-4",
)
prompt = ChatPromptTemplate.from_template("""仅根据所提供的上下文回答以下问题,并且禁止添加自己认为的信息,如果上下文没有,输出暂无结果:
<context>
{context}
</context>
问题: {input}
""")
output_parser = StrOutputParser()
embeddings = HuggingFaceEmbeddings(model_name=huggingface_path)
text_splitter = RecursiveCharacterTextSplitter()
def read_pdf(file_path):
text = ""
with pdfplumber.open(file_path) as pdf:
for page in pdf.pages:
text += page.extract_text() or ""
return text
def original(data: str) -> str:
prompt_none = ChatPromptTemplate.from_messages([
("system", "您是一个助手,请回答用户的问题。"),
("user", "{input}")
])
output_parser_original = StrOutputParser()
chain = prompt_none | llm | output_parser_original
response = chain.invoke({"input": data}) # 调用链条,并传入数据
return response
def process_uploaded_file(uploaded_file, question):
# 生成原始模型输出
original_answer = original(question)
file_path = uploaded_file.name
if file_path.endswith(".pdf"):
text = read_pdf(file_path)
else:
loader = TextLoader(file_path)
text = loader.load()[0].page_content
# 文本分割和嵌入处理
documents = text_splitter.split_text(text)
vector = FAISS.from_texts(documents, embeddings)
retriever = vector.as_retriever() # 返回相关片段
document_chain = create_stuff_documents_chain(llm, prompt)
retrieval_chain = create_retrieval_chain(retriever, document_chain)
response = retrieval_chain.invoke({"input": question})
# 获取问题的嵌入向量
question_embedding = embeddings.embed_query(question)
# 获取所有文档的嵌入向量
document_embeddings = embeddings.embed_documents(documents)
# 计算每个文档与问题之间的余弦相似度
scores = cosine_similarity([question_embedding], document_embeddings)[0]
# 打印所有文档的分数以进行调试
print("所有文档的相似度得分:")
for i, (doc, score) in enumerate(zip(documents, scores)):
print(f"段落 {i + 1}:")
print(f"内容: {doc[:100]}...") # 打印前100个字符避免输出过多
print(f"得分: {score}")
print("-" * 40) # 分隔线,方便阅读
# 筛选出得分大于等于 0.1 的文档(暂时降低阈值以确保有输出)
high_score_docs = [
(doc, score) for doc, score in zip(documents, scores) if score >= 0.3
]
# 如果没有高分文档,返回默认信息
if not high_score_docs:
return original_answer, "未找到相关信息", "未找到相关上下文"
# 提取高分文档上下文内容
context = "\n\n".join([doc for doc, _ in high_score_docs[:3]])
# 使用上下文调用模型
# 检查回答内容与置信度
answer = response.get("answer", "").strip()
confidence = response.get("confidence", 1.0)
# 如果无答案或置信度低于阈值,返回默认信息
if not answer or confidence < 0.5:
answer = "未找到相关信息"
return original_answer, answer, context
def predict(history):
stop = StopOnTokens()
messages = []
# 在模型生成响应之前输出历史记录
for idx, (user_msg, model_msg) in enumerate(history):
# 在处理 history 时,确保每个条目是 [user_message, model_message] 的格式
if idx == len(history) - 1 and not model_msg:
messages.append({"role": "user", "content": user_msg})
break
if user_msg:
messages.append({"role": "user", "content": user_msg})
if model_msg:
messages.append({"role": "assistant", "content": model_msg})
# 如果 messages 为空,直接返回
if not messages:
return history
model_inputs = tokenizer.apply_chat_template(messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt").to(next(model.parameters()).device)
streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = {
"input_ids": model_inputs,
"streamer": streamer,
"max_new_tokens": 8192, # 固定的 max_length 值
"do_sample": True,
"top_p": 0.7, # 固定的 top_p 值
"temperature": 0.9, # 固定的 temperature 值
"stopping_criteria": StoppingCriteriaList([stop]),
"repetition_penalty": 1.2,
"eos_token_id": model.config.eos_token_id,
}
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
for new_token in streamer:
# 确保 history 的最后一项是 [user_message, model_message] 的格式
if len(history[-1]) < 2:
history[-1].append(new_token) # 添加模型的响应
else:
history[-1][1] += new_token # 如果已经有模型响应,继续追加
# 输出生成新 token 后的历史记录
yield history
# 图片分类使用Resnet进行图片分类
# path/home/wangty/lsx/temp-lsx/word_language_model/recog_image/best_resnet18_model.pth
model_resnet = models.resnet18(weights=None)
model_resnet = nn.DataParallel(model_resnet)
num_classes = 7
model_resnet.module.fc = nn.Linear(model_resnet.module.fc.in_features, num_classes)
model_resnet.load_state_dict(torch.load('/home/wangty/lsx/temp-lsx/recog_image/code/best_resnet18_model.pth', map_location=torch.device('cuda')))
model_resnet = model_resnet.to('cuda:0')
model_resnet.eval()
# 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
class_names = {0: 'akiec', 1: 'bcc', 2: 'bkl', 3: 'df', 4: 'mel', 5: 'nv', 6: 'vasc'}
translations = { # 翻译映射
"akiec": "光化性角化病和皮内癌 / 鲍温病",
"bcc": "基底细胞癌",
"bkl": "良性角化样病变",
"df": "皮肤纤维瘤",
"mel": "黑色素瘤",
"nv": "黑素细胞痣",
"vasc": "血管病变",
}
def translate(name):
translated_text = " ".join([translations.get(word, word) for word in name.split()])
return translated_text
def predict_image(image):
# 检查并确保图像是 RGB 模式
if not image.mode == 'RGB':
image = image.convert('RGB') # 转换到 RGB 模式
# 应用图像预处理
image = transform(image).unsqueeze(0) # 转换成批次格式并增加维度
# 确保模型是图像分类模型
model_resnet.eval()
with torch.no_grad():
outputs = model_resnet(image) # 传入模型并获取输出
_, predicted = torch.max(outputs, 1)
predicted_class = class_names[predicted.item()]
name = translate(predicted_class)
# prompt_image = ChatPromptTemplate.from_messages([
# ("system", "你是一个专业的医疗专家,解释问题是什么,该如何治疗。文本开头加入根据您的图片分析等字样"),
# ("user", "{input}")
# ])
# output_parser_original = StrOutputParser()
# chain = prompt_image | llm | output_parser_original
# response = chain.invoke({"input": name}) # 调用链条,并传入数据
stop = StopOnTokens()
messages = [{"role": "user", "content": name}] # 只使用用户输入的消息
# 处理消息以生成模型输入2
model_inputs = tokenizer.apply_chat_template(messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt").to(next(model.parameters()).device)
streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = {
"input_ids": model_inputs,
"streamer": streamer,
"max_new_tokens": 8192, # 固定的 max_length 值
"do_sample": True,
"top_p": 0.7, # 固定的 top_p 值
"temperature": 0.9, # 固定的 temperature 值
"stopping_criteria": StoppingCriteriaList([stop]),
"repetition_penalty": 1.2,
"eos_token_id": model.config.eos_token_id,
}
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
model_reply = ""
for new_token in streamer:
model_reply += new_token # 收集模型的回复
return name,model_reply # 返回最终的模型回复
# 构建Gradio界面
with gr.Blocks(css=custom_css) as demo:
with gr.Tabs():
with gr.Tab("对话问诊"):
# 添加HTML标题
gr.HTML(html_content)
with gr.Row():
with gr.Column():
with gr.Group(elem_id="example-container"):
example_buttons = [gr.Button(example, elem_id="example-btn") for example in examples]
with gr.Column(scale=10):
chatbot = gr.Chatbot(elem_id="gr-chatbot")
with gr.Row():
user_input = gr.Textbox(show_label=False, placeholder="请输入咨询的问题:", scale=20,
elem_id="textbox")
submitBtn = gr.Button("发送", elem_id="send-btn")
def user(query, history):
# 确保 history 中每一项都是长度为 2 的列表
if history is None:
history = []
history.append([query, ""]) # 第二个元素是空字符串,表示初始没有模型回复
return "", history
def send_example(example, history): # 处理用户点击示例按钮后自动发送
if history is None:
history = []
history.append([example, ""]) # 将示例作为输入,模型响应初始化为空
return history
submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
predict, [chatbot], chatbot
)
# 修改点击事件点击时直接调用user和predict函数
for example_button, example_text in zip(example_buttons, examples):
example_button.click(
send_example,
inputs=[gr.State(value=example_text), chatbot],
outputs=chatbot
).then(
predict,
inputs=[chatbot],
outputs=chatbot
)
with gr.Tab("搜索专业文档"):
gr.Markdown("## 文件上传和模型问答")
with gr.Row():
with gr.Column():
file_input = gr.File(label="上传文件")
with gr.Column():
question_input = gr.Textbox(label="请输入问题",lines=8)
with gr.Row():
original_output = gr.Textbox(label="原模型输出")
output = gr.Textbox(label="模型回答")
with gr.Row():
context_output = gr.Textbox(label="相关上下文",scale=28)
submit_button = gr.Button("提交")
submit_button.click(process_uploaded_file, inputs=[file_input, question_input],
outputs=[original_output, output, context_output])
with gr.Tab("上传患病处图像"):
gr.Markdown("## 图像分类")
with gr.Row():
image_input = gr.Image(type="pil", label="上传患病处图像")
result_output = gr.Textbox(label="知识输出", lines=9)
with gr.Row():
image_output = gr.Textbox(label="识别结果", scale=28)
classify_button = gr.Button("识别")
classify_button.click(predict_image, inputs=image_input, outputs=[image_output, result_output])
demo.queue()
demo.launch(share=True)
Loading…
Cancel
Save