From 37b9a27ca2baad64d0028de84b74780e479cff31 Mon Sep 17 00:00:00 2001 From: pycsq8k9h <1272574577@qq.com> Date: Sat, 4 Jan 2025 10:09:07 +0800 Subject: [PATCH] ADD file via upload --- web.py | 966 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 966 insertions(+) create mode 100644 web.py diff --git a/web.py b/web.py new file mode 100644 index 0000000..ef1b3f4 --- /dev/null +++ b/web.py @@ -0,0 +1,966 @@ +from __future__ import annotations +import os +from PIL import Image +import torch.nn as nn +from torchvision import models, transforms +from pathlib import Path +from threading import Thread +from typing import Union +import gradio as gr +import torch +from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + StoppingCriteria, + StoppingCriteriaList, + TextIteratorStreamer, AutoModel +) +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.output_parsers import StrOutputParser +from langchain_community.document_loaders import WebBaseLoader +from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_community.vectorstores import FAISS +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain.chains.combine_documents import create_stuff_documents_chain +from langchain_community.document_loaders import TextLoader +from langchain.chains import create_retrieval_chain +from langchain.tools.retriever import create_retriever_tool +import asyncio +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity +import logging +import pdfplumber +from functools import partial +from importlib.metadata import version +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Tuple, + Type, + Union, +) + +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + 
def convert_message_to_dict(message: BaseMessage) -> dict:
    """Serialize a LangChain message into a role/content dictionary.

    Args:
        message: The LangChain message to serialize.

    Returns:
        A dict with ``role`` and ``content`` keys; ``tool_calls``,
        ``tool_call_id`` and ``name`` are added when the message carries them.

    Raises:
        TypeError: If ``message`` is none of the handled message classes.
    """
    payload: Dict[str, Any]

    if isinstance(message, ChatMessage):
        # Chat messages carry their own free-form role.
        payload = dict(role=message.role, content=message.content)
    elif isinstance(message, HumanMessage):
        payload = dict(role="user", content=message.content)
    elif isinstance(message, AIMessage):
        has_tool_calls = "tool_calls" in message.additional_kwargs
        body = message.content
        # A tool-call-only reply is sent with content None, not "".
        if has_tool_calls and body == "":
            body = None
        payload = dict(role="assistant", content=body)
        if has_tool_calls:
            payload["tool_calls"] = message.additional_kwargs["tool_calls"]
    elif isinstance(message, SystemMessage):
        payload = dict(role="system", content=message.content)
    elif isinstance(message, ToolMessage):
        payload = dict(
            role="tool",
            content=message.content,
            tool_call_id=message.tool_call_id,
        )
    else:
        raise TypeError(f"Got unknown type {message}")

    # An optional participant name travels in additional_kwargs.
    extra = message.additional_kwargs
    if "name" in extra:
        payload["name"] = extra["name"]
    return payload
+ """ + role = _dict.get("role") + if role == "user": + return HumanMessage(content=_dict.get("content", "")) + elif role == "assistant": + content = _dict.get("content", "") or "" + additional_kwargs: Dict = {} + if tool_calls := _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = tool_calls + return AIMessage(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessage(content=_dict.get("content", "")) + elif role == "tool": + additional_kwargs = {} + if "name" in _dict: + additional_kwargs["name"] = _dict["name"] + return ToolMessage( + content=_dict.get("content", ""), + tool_call_id=_dict.get("tool_call_id"), + additional_kwargs=additional_kwargs, + ) + else: + return ChatMessage(content=_dict.get("content", ""), role=role) + + +def _convert_delta_to_message_chunk( + _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] +) -> BaseMessageChunk: + role = _dict.get("role") + content = _dict.get("content") or "" + additional_kwargs: Dict = {} + if _dict.get("tool_calls"): + additional_kwargs["tool_calls"] = _dict["tool_calls"] + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content) + elif role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + elif role == "system" or default_class == SystemMessageChunk: + return SystemMessageChunk(content=content) + elif role == "tool" or default_class == ToolMessageChunk: + return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"]) + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role) + else: + return default_class(content=content) + + +class ChatZhipuAI(BaseChatModel): + """ + `ZHIPU AI` large language chat models API. + + To use, you should have the ``zhipuai`` python package installed. + + Example: + .. 
code-block:: python + + from langchain_community.chat_models import ChatZhipuAI + + zhipuai_chat = ChatZhipuAI( + temperature=0.5, + api_key="your-api-key", + model_name="glm-3-turbo", + ) + + """ + + zhipuai: Any + zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key") + """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided.""" + + client: Any = Field(default=None, exclude=True) #: :meta private: + + model_name: str = Field("glm-3-turbo", alias="model") + """ + Model name to use. + -glm-3-turbo: + According to the input of natural language instructions to complete a + variety of language tasks, it is recommended to use SSE or asynchronous + call request interface. + -glm-4: + According to the input of natural language instructions to complete a + variety of language tasks, it is recommended to use SSE or asynchronous + call request interface. + """ + + temperature: float = Field(0.95) + """ + What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot + be equal to 0. + The larger the value, the more random and creative the output; The smaller + the value, the more stable or certain the output will be. + You are advised to adjust top_p or temperature parameters based on application + scenarios, but do not adjust the two parameters at the same time. + """ + + top_p: float = Field(0.7) + """ + Another method of sampling temperature is called nuclear sampling. The value + ranges from 0.0 to 1.0 and cannot be equal to 0 or 1. + The model considers the results with top_p probability quality tokens. + For example, 0.1 means that the model decoder only considers tokens from the + top 10% probability of the candidate set. + You are advised to adjust top_p or temperature parameters based on application + scenarios, but do not adjust the two parameters at the same time. 
+ """ + + request_id: Optional[str] = Field(None) + """ + Parameter transmission by the client must ensure uniqueness; A unique + identifier used to distinguish each request, which is generated by default + by the platform when the client does not transmit it. + """ + do_sample: Optional[bool] = Field(True) + """ + When do_sample is true, the sampling policy is enabled. When do_sample is false, + the sampling policy temperature and top_p are disabled + """ + streaming: bool = Field(False) + """Whether to stream the results or not.""" + + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" + + max_tokens: Optional[int] = None + """Number of chat completions to generate for each prompt.""" + + max_retries: int = 2 + """Maximum number of retries to make when generating.""" + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Get the identifying parameters.""" + return {**{"model_name": self.model_name}, **self._default_params} + + @property + def _llm_type(self) -> str: + """Return the type of chat model.""" + return "zhipuai" + + @property + def lc_secrets(self) -> Dict[str, str]: + return {"zhipuai_api_key": "ZHIPUAI_API_KEY"} + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """Get the namespace of the langchain object.""" + return ["langchain", "chat_models", "zhipuai"] + + @property + def lc_attributes(self) -> Dict[str, Any]: + attributes: Dict[str, Any] = {} + + if self.model_name: + attributes["model"] = self.model_name + + if self.streaming: + attributes["streaming"] = self.streaming + + if self.max_tokens: + attributes["max_tokens"] = self.max_tokens + + return attributes + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling ZhipuAI API.""" + params = { + "model": self.model_name, + "stream": self.streaming, + "temperature": self.temperature, + "top_p": self.top_p, + "do_sample": 
def __init__(self, *args, **kwargs):
    """Initialise the model fields, then build the ZhipuAI SDK client.

    All positional and keyword arguments are forwarded to the base-class
    constructor. The SDK client is created with ``self.zhipuai_api_key``
    and stored on ``self.client``.

    Raises:
        RuntimeError: If the ``zhipuai`` package cannot be imported, or if
            the installed version is older than v2.
    """
    super().__init__(*args, **kwargs)

    # Keep the try body to the import only, so unrelated errors raised while
    # constructing the client are not reported as "package missing".
    try:
        from zhipuai import ZhipuAI
    except ImportError as err:
        raise RuntimeError(
            "Could not import zhipuai package. "
            "Please install it via 'pip install zhipuai'"
        ) from err

    if not is_zhipu_v2():
        # The original literals were concatenated without a separator,
        # producing "...too lowPlease install...".
        raise RuntimeError(
            "zhipuai package version is too low. "
            "Please install it via 'pip install --upgrade zhipuai'"
        )

    self.client = ZhipuAI(api_key=self.zhipuai_api_key)
"token_usage": token_usage, + "model_name": self.model_name, + "task_id": response.get("id", ""), + "created_time": response.get("created", ""), + } + return ChatResult(generations=generations, llm_output=llm_output) + + def _create_message_dicts( + self, messages: List[BaseMessage], stop: Optional[List[str]] + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + params = self._client_params + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + message_dicts = [convert_message_to_dict(m) for m in messages] + return message_dicts, params + + def completion_with_retry( + self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any + ) -> Any: + """Use tenacity to retry the completion call.""" + + retry_decorator = _create_retry_decorator(self, run_manager=run_manager) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + return self.completions(**kwargs) + + return _completion_with_retry(**kwargs) + + async def acompletion_with_retry( + self, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Any: + """Use tenacity to retry the async completion call.""" + + retry_decorator = _create_retry_decorator(self, run_manager=run_manager) + + @retry_decorator + async def _completion_with_retry(**kwargs: Any) -> Any: + return await self.async_completions(**kwargs) + + return await _completion_with_retry(**kwargs) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, + **kwargs: Any, + ) -> ChatResult: + """Generate a chat response.""" + + should_stream = stream if stream is not None else self.streaming + if should_stream: + stream_iter = self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return generate_from_stream(stream_iter) + + message_dicts, params = 
async def _agenerate(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    stream: Optional[bool] = False,
    **kwargs: Any,
) -> ChatResult:
    """Asynchronously generate a chat response.

    Args:
        messages: Conversation to send to the model.
        stop: Optional stop sequences, forwarded to the request parameters.
        run_manager: Callback manager for this run.
        stream: Whether to stream. Because the default is ``False``,
            ``self.streaming`` is consulted only when ``None`` is passed
            explicitly.
        **kwargs: Extra request parameters; they override the defaults.

    Returns:
        The chat result built from the API response, or aggregated from the
        streamed chunks when streaming.
    """
    should_stream = stream if stream is not None else self.streaming
    if should_stream:
        # Imported locally: the file-level imports only bring in the
        # synchronous ``generate_from_stream``, which cannot consume an
        # async iterator.
        from langchain_core.language_models.chat_models import (
            agenerate_from_stream,
        )

        # NOTE(review): ``_astream`` is not defined in the visible part of
        # this class; presumably it is inherited from BaseChatModel - confirm.
        stream_iter = self._astream(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        return await agenerate_from_stream(stream_iter)

    message_dicts, params = self._create_message_dicts(messages, stop)
    params = {
        **params,
        **({"stream": stream} if stream is not None else {}),
        **kwargs,
    }
    response = await self.acompletion_with_retry(
        messages=message_dicts, run_manager=run_manager, **params
    )
    return self._create_chat_result(response)
# Custom CSS for the Gradio page, passed to gr.Blocks(css=...).
# Four declarations previously used "!重要" (a mistranslated "!important"),
# which is invalid CSS, so browsers dropped those declarations entirely.
custom_css = """
#gr-chatbot {
    background-color: #FFFFFF !important; /* 设置Chatbot组件背景为浅灰色 */
    color: #000000 !important; /* 设置Chatbot字体颜色为黑色 */
    border: 2px solid #ccc; /* 设置Chatbot边框 */
    border-radius: 20px; /* 设置圆角边框 */
    padding: 10px;
    height: 1600px;
    display: block !important; /* 强制为块级元素 */
}
#example-btn {

    width: 250px !important; /* 按钮宽度变窄 */
    font-size: 20px !important; /* 按钮字体大小 */
    font-family: "仿宋", serif !important; /* 设置字体为宋体 */
    padding: 5px !important; /* 调整按钮内边距 */
    margin-bottom: 5px; /* 按钮之间的间距 */
    background-color: #E4E4E7 !important;
    border: none !important; /* 去掉按钮的边框 */
}
#example-container {
    border: 2px solid #ccc; /* 设置方框边框 */
    padding:10px; /* 增加内边距 */
    border-radius: 30px; /* 设置圆角 */
    margin-bottom: 10px; /* 与其他元素保持间距 */
    width: 285px !important; /* 固定按钮容器宽度 */
    height: 460px !important; /* 固定按钮容器高度 */
    overflow-y: auto !important; /* 自动显示垂直滚动条 */
    overflow-x: hidden !important; /* 禁用水平滚动条 */
    display: block !important; /* 强制为块级元素 */
    background-color: #E4E4E7 !important;
}
#send-btn {
    width: 60px !important; /* 缩小按钮宽度 */
    height: 60px !important; /* 缩小按钮的高度 */
    font-size: 20px; /* 调整按钮字体大小 */
    font-family: "仿宋", serif !important; /* 设置字体为宋体 */
    padding: 5px; /* 设置按钮的内边距 */
    vertical-align: middle; /* 按钮内容垂直居中 */
    background-color: #FFFFFF !important; /* 按钮背景为白色 */
    border: 1px solid #CCCCCC !important; /* 将按钮的边框设置为黑色 */
}

#textbox {
    height: 60px !important; /* 设置输入框的高度 */
    line-height: normal !important; /* 确保行高正常 */
    font-size: 20px !important; /* 调整输入框的字体大小 */
}
"""

医疗聊天机器人

class StopOnTokens(StoppingCriteria):
    """Stop generation once the last emitted token is one of the model's EOS ids.

    The ids are read from the module-level ``model.config.eos_token_id`` on
    every call.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the newest token of the first sequence is an EOS id."""
        eos = model.config.eos_token_id
        if eos is None:
            return False
        # eos_token_id may be a single int or a list of ints; iterating a
        # bare int raised TypeError before.
        stop_ids = (eos,) if isinstance(eos, int) else tuple(eos)
        last_token = int(input_ids[0][-1])
        return last_token in stop_ids


# SECURITY(review): an API key is committed to source control below. The
# environment variable now takes precedence; the literal is kept only so
# existing deployments keep working. Rotate this key and delete the fallback.
ZHIPUAI_API_KEY = os.environ.get(
    "ZHIPUAI_API_KEY", "4a1f42142b2adfd2bdcbb864eb8ddf8f.6o0LAkhsANjQHGlJ"
)
llm = ChatZhipuAI(
    temperature=0.1,
    api_key=ZHIPUAI_API_KEY,
    model_name="glm-4",
)
def process_uploaded_file(uploaded_file, question):
    """Answer ``question`` with the bare LLM and with retrieval over an uploaded file.

    Args:
        uploaded_file: Value from the ``gr.File`` input. Either an object
            with a ``name`` attribute holding the path, a plain path, or
            ``None`` when nothing was uploaded.
        question: The user's question.

    Returns:
        A ``(original_answer, answer, context)`` tuple. ``original_answer``
        is the LLM reply without any document. ``answer`` is the
        retrieval-augmented reply. ``context`` is the text of up to three
        chunks whose cosine similarity to the question is at least 0.3.
        When no file, no text, or no sufficiently similar chunk is
        available, ``answer`` and ``context`` are fixed "not found" messages.
    """
    not_found_answer = "未找到相关信息"
    not_found_context = "未找到相关上下文"
    min_score = 0.3  # similarity threshold for a chunk to count as relevant
    max_context_chunks = 3

    # Baseline answer from the model without any document context.
    original_answer = original(question)

    if uploaded_file is None:
        return original_answer, not_found_answer, not_found_context

    # Accept both a file wrapper exposing ``.name`` and a plain path.
    file_path = str(getattr(uploaded_file, "name", uploaded_file))
    if file_path.lower().endswith(".pdf"):
        text = read_pdf(file_path)
    else:
        text = TextLoader(file_path).load()[0].page_content

    documents = text_splitter.split_text(text)
    if not documents:
        # An empty file would otherwise fail while building the index.
        return original_answer, not_found_answer, not_found_context

    # Embed once; the same vectors feed both the scoring and the index.
    question_embedding = embeddings.embed_query(question)
    document_embeddings = embeddings.embed_documents(documents)
    scores = cosine_similarity([question_embedding], document_embeddings)[0]

    for index, (doc, score) in enumerate(zip(documents, scores), start=1):
        logger.debug("chunk %d: score=%.4f preview=%r", index, score, doc[:100])

    relevant_docs = [
        doc for doc, score in zip(documents, scores) if score >= min_score
    ]
    if not relevant_docs:
        # Bail out before the retrieval chain spends an LLM call whose
        # result would be discarded anyway.
        return original_answer, not_found_answer, not_found_context

    context = "\n\n".join(relevant_docs[:max_context_chunks])

    vector = FAISS.from_embeddings(
        list(zip(documents, document_embeddings)), embeddings
    )
    retriever = vector.as_retriever()
    document_chain = create_stuff_documents_chain(llm, prompt)
    retrieval_chain = create_retrieval_chain(retriever, document_chain)
    response = retrieval_chain.invoke({"input": question})

    answer = response.get("answer", "").strip()
    # NOTE(review): the chain is not expected to return a "confidence" key,
    # so the default normally applies; the check is kept for compatibility.
    confidence = response.get("confidence", 1.0)
    if not answer or confidence < 0.5:
        answer = not_found_answer

    return original_answer, answer, context
def predict_image(image):
    """Classify an uploaded image and ask the chat model about the predicted class.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing was uploaded.

    Returns:
        A ``(name, model_reply)`` tuple: the translated class name and the
        full text generated by the chat model for that name. When no image
        is supplied, a prompt to upload one and an empty reply.
    """
    if image is None:
        # The button can be clicked before anything is uploaded.
        return "请先上传患病处图像", ""

    # The preprocessing pipeline normalises three channels.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    batch = transform(image).unsqueeze(0)  # add the batch dimension

    model_resnet.eval()
    with torch.no_grad():
        outputs = model_resnet(batch)
        _, predicted = torch.max(outputs, 1)
        predicted_class = class_names[predicted.item()]

    name = translate(predicted_class)

    # The translated class name is the only user message sent to the model.
    messages = [{"role": "user", "content": name}]
    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    ).to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": 8192,
        "do_sample": True,
        "top_p": 0.7,
        "temperature": 0.9,
        "stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
        "repetition_penalty": 1.2,
        "eos_token_id": model.config.eos_token_id,
    }

    # generate() blocks, so it runs in a worker thread while this thread
    # drains the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    model_reply = "".join(streamer)
    worker.join()  # do not leave the generation thread dangling

    return name, model_reply
gr.Textbox(label="原模型输出") + output = gr.Textbox(label="模型回答") + with gr.Row(): + context_output = gr.Textbox(label="相关上下文",scale=28) + submit_button = gr.Button("提交") + submit_button.click(process_uploaded_file, inputs=[file_input, question_input], + outputs=[original_output, output, context_output]) + with gr.Tab("上传患病处图像"): + gr.Markdown("## 图像分类") + with gr.Row(): + image_input = gr.Image(type="pil", label="上传患病处图像") + result_output = gr.Textbox(label="知识输出", lines=9) + with gr.Row(): + image_output = gr.Textbox(label="识别结果", scale=28) + classify_button = gr.Button("识别") + classify_button.click(predict_image, inputs=image_input, outputs=[image_output, result_output]) + + +demo.queue() +demo.launch(share=True)