from __future__ import annotations import os from PIL import Image import torch.nn as nn from torchvision import models, transforms from pathlib import Path from threading import Thread from typing import Union import gradio as gr import torch from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM from transformers import ( AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer, AutoModel ) from langchain_core.prompts import ChatPromptTemplate from langchain_core.output_parsers import StrOutputParser from langchain_community.document_loaders import WebBaseLoader from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_community.vectorstores import FAISS from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.chains.combine_documents import create_stuff_documents_chain from langchain_community.document_loaders import TextLoader from langchain.chains import create_retrieval_chain from langchain.tools.retriever import create_retriever_tool import asyncio import numpy as np from sklearn.metrics.pairwise import cosine_similarity import logging import pdfplumber from functools import partial from importlib.metadata import version from typing import ( Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Type, Union, ) from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain_core.language_models.chat_models import ( BaseChatModel, generate_from_stream, ) from langchain_core.language_models.llms import create_base_retry_decorator from langchain_core.messages import ( AIMessage, AIMessageChunk, BaseMessage, BaseMessageChunk, ChatMessage, ChatMessageChunk, HumanMessage, HumanMessageChunk, SystemMessage, SystemMessageChunk, ToolMessage, ToolMessageChunk, ) from langchain_core.outputs import ( ChatGeneration, ChatGenerationChunk, ChatResult, ) from langchain_core.pydantic_v1 import BaseModel, Field from packaging.version import parse logger = logging.getLogger(__name__) huggingface_path = "/home/wangty/lsx/temp-lsx/word_language_model/models--sentence-transformers--all-mpnet-base-v2" """ZhipuAI与langchain部分不兼容,重写langchain部分代码""" def is_zhipu_v2() -> bool: """Return whether zhipu API is v2 or more.""" _version = parse(version("zhipuai")) return _version.major >= 2 def _create_retry_decorator( llm: ChatZhipuAI, run_manager: Optional[ Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] ] = None, ) -> Callable[[Any], Any]: import zhipuai errors = [ zhipuai.ZhipuAIError, zhipuai.APIStatusError, zhipuai.APIRequestFailedError, zhipuai.APIReachLimitError, zhipuai.APIInternalError, zhipuai.APIServerFlowExceedError, zhipuai.APIResponseError, zhipuai.APIResponseValidationError, zhipuai.APITimeoutError, ] return create_base_retry_decorator( error_types=errors, max_retries=llm.max_retries, run_manager=run_manager ) def convert_message_to_dict(message: BaseMessage) -> dict: """Convert a LangChain message to a dictionary. Args: message: The LangChain message. Returns: The dictionary. """ message_dict: Dict[str, Any] if isinstance(message, ChatMessage): message_dict = {"role": message.role, "content": message.content} elif isinstance(message, HumanMessage): message_dict = {"role": "user", "content": message.content} elif isinstance(message, AIMessage): message_dict = {"role": "assistant", "content": message.content} if "tool_calls" in message.additional_kwargs: message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] # If tool calls only, content is None not empty string if message_dict["content"] == "": message_dict["content"] = None elif isinstance(message, SystemMessage): message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolMessage): message_dict = { "role": "tool", "content": message.content, "tool_call_id": message.tool_call_id, } else: raise TypeError(f"Got unknown type {message}") if "name" in message.additional_kwargs: message_dict["name"] = message.additional_kwargs["name"] return message_dict def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: """Convert a dictionary to a LangChain message. Args: _dict: The dictionary. Returns: The LangChain message. """ role = _dict.get("role") if role == "user": return HumanMessage(content=_dict.get("content", "")) elif role == "assistant": content = _dict.get("content", "") or "" additional_kwargs: Dict = {} if tool_calls := _dict.get("tool_calls"): additional_kwargs["tool_calls"] = tool_calls return AIMessage(content=content, additional_kwargs=additional_kwargs) elif role == "system": return SystemMessage(content=_dict.get("content", "")) elif role == "tool": additional_kwargs = {} if "name" in _dict: additional_kwargs["name"] = _dict["name"] return ToolMessage( content=_dict.get("content", ""), tool_call_id=_dict.get("tool_call_id"), additional_kwargs=additional_kwargs, ) else: return ChatMessage(content=_dict.get("content", ""), role=role) def _convert_delta_to_message_chunk( _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] ) -> BaseMessageChunk: role = _dict.get("role") content = _dict.get("content") or "" additional_kwargs: Dict = {} if _dict.get("tool_calls"): additional_kwargs["tool_calls"] = _dict["tool_calls"] if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content) elif role == "assistant" or default_class == AIMessageChunk: return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) elif role == "system" or default_class == SystemMessageChunk: return SystemMessageChunk(content=content) elif role == "tool" or default_class == ToolMessageChunk: return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"]) elif role or default_class == ChatMessageChunk: return ChatMessageChunk(content=content, role=role) else: return default_class(content=content) class ChatZhipuAI(BaseChatModel): """ `ZHIPU AI` large language chat models API. To use, you should have the ``zhipuai`` python package installed. Example: .. code-block:: python from langchain_community.chat_models import ChatZhipuAI zhipuai_chat = ChatZhipuAI( temperature=0.5, api_key="your-api-key", model_name="glm-3-turbo", ) """ zhipuai: Any zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key") """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided.""" client: Any = Field(default=None, exclude=True) #: :meta private: model_name: str = Field("glm-3-turbo", alias="model") """ Model name to use. -glm-3-turbo: According to the input of natural language instructions to complete a variety of language tasks, it is recommended to use SSE or asynchronous call request interface. -glm-4: According to the input of natural language instructions to complete a variety of language tasks, it is recommended to use SSE or asynchronous call request interface. """ temperature: float = Field(0.95) """ What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot be equal to 0. The larger the value, the more random and creative the output; The smaller the value, the more stable or certain the output will be. You are advised to adjust top_p or temperature parameters based on application scenarios, but do not adjust the two parameters at the same time. """ top_p: float = Field(0.7) """ Another method of sampling temperature is called nuclear sampling. The value ranges from 0.0 to 1.0 and cannot be equal to 0 or 1. The model considers the results with top_p probability quality tokens. For example, 0.1 means that the model decoder only considers tokens from the top 10% probability of the candidate set. You are advised to adjust top_p or temperature parameters based on application scenarios, but do not adjust the two parameters at the same time. """ request_id: Optional[str] = Field(None) """ Parameter transmission by the client must ensure uniqueness; A unique identifier used to distinguish each request, which is generated by default by the platform when the client does not transmit it. """ do_sample: Optional[bool] = Field(True) """ When do_sample is true, the sampling policy is enabled. When do_sample is false, the sampling policy temperature and top_p are disabled """ streaming: bool = Field(False) """Whether to stream the results or not.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" max_tokens: Optional[int] = None """Number of chat completions to generate for each prompt.""" max_retries: int = 2 """Maximum number of retries to make when generating.""" @property def _identifying_params(self) -> Dict[str, Any]: """Get the identifying parameters.""" return {**{"model_name": self.model_name}, **self._default_params} @property def _llm_type(self) -> str: """Return the type of chat model.""" return "zhipuai" @property def lc_secrets(self) -> Dict[str, str]: return {"zhipuai_api_key": "ZHIPUAI_API_KEY"} @classmethod def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object.""" return ["langchain", "chat_models", "zhipuai"] @property def lc_attributes(self) -> Dict[str, Any]: attributes: Dict[str, Any] = {} if self.model_name: attributes["model"] = self.model_name if self.streaming: attributes["streaming"] = self.streaming if self.max_tokens: attributes["max_tokens"] = self.max_tokens return attributes @property def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling ZhipuAI API.""" params = { "model": self.model_name, "stream": self.streaming, "temperature": self.temperature, "top_p": self.top_p, "do_sample": self.do_sample, **self.model_kwargs, } if self.max_tokens is not None: params["max_tokens"] = self.max_tokens return params @property def _client_params(self) -> Dict[str, Any]: """Get the parameters used for the zhipuai client.""" zhipuai_creds: Dict[str, Any] = { "request_id": self.request_id, } return {**self._default_params, **zhipuai_creds} def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) try: from zhipuai import ZhipuAI if not is_zhipu_v2(): raise RuntimeError( "zhipuai package version is too low" "Please install it via 'pip install --upgrade zhipuai'" ) self.client = ZhipuAI( api_key=self.zhipuai_api_key, # 填写您的 APIKey ) except ImportError: raise RuntimeError( "Could not import zhipuai package. " "Please install it via 'pip install zhipuai'" ) def completions(self, **kwargs) -> Any | None: return self.client.chat.completions.create(**kwargs) async def async_completions(self, **kwargs) -> Any: loop = asyncio.get_running_loop() partial_func = partial(self.client.chat.completions.create, **kwargs) response = await loop.run_in_executor( None, partial_func, ) return response async def async_completions_result(self, task_id): loop = asyncio.get_running_loop() response = await loop.run_in_executor( None, self.client.asyncCompletions.retrieve_completion_result, task_id, ) return response def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult: generations = [] if not isinstance(response, dict): response = response.dict() for res in response["choices"]: message = convert_dict_to_message(res["message"]) generation_info = dict(finish_reason=res.get("finish_reason")) if "index" in res: generation_info["index"] = res["index"] gen = ChatGeneration( message=message, generation_info=generation_info, ) generations.append(gen) token_usage = response.get("usage", {}) llm_output = { "token_usage": token_usage, "model_name": self.model_name, "task_id": response.get("id", ""), "created_time": response.get("created", ""), } return ChatResult(generations=generations, llm_output=llm_output) def _create_message_dicts( self, messages: List[BaseMessage], stop: Optional[List[str]] ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: params = self._client_params if stop is not None: if "stop" in params: raise ValueError("`stop` found in both the input and default params.") params["stop"] = stop message_dicts = [convert_message_to_dict(m) for m in messages] return message_dicts, params def completion_with_retry( self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any ) -> Any: """Use tenacity to retry the completion call.""" retry_decorator = _create_retry_decorator(self, run_manager=run_manager) @retry_decorator def _completion_with_retry(**kwargs: Any) -> Any: return self.completions(**kwargs) return _completion_with_retry(**kwargs) async def acompletion_with_retry( self, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Any: """Use tenacity to retry the async completion call.""" retry_decorator = _create_retry_decorator(self, run_manager=run_manager) @retry_decorator async def _completion_with_retry(**kwargs: Any) -> Any: return await self.async_completions(**kwargs) return await _completion_with_retry(**kwargs) def _generate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, stream: Optional[bool] = None, **kwargs: Any, ) -> ChatResult: """Generate a chat response.""" should_stream = stream if stream is not None else self.streaming if should_stream: stream_iter = self._stream( messages, stop=stop, run_manager=run_manager, **kwargs ) return generate_from_stream(stream_iter) message_dicts, params = self._create_message_dicts(messages, stop) params = { **params, **({"stream": stream} if stream is not None else {}), **kwargs, } response = self.completion_with_retry( messages=message_dicts, run_manager=run_manager, **params ) return self._create_chat_result(response) async def _agenerate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, stream: Optional[bool] = False, **kwargs: Any, ) -> ChatResult: """Asynchronously generate a chat response.""" should_stream = stream if stream is not None else self.streaming if should_stream: stream_iter = self._astream( messages, stop=stop, run_manager=run_manager, **kwargs ) return generate_from_stream(stream_iter) message_dicts, params = self._create_message_dicts(messages, stop) params = { **params, **({"stream": stream} if stream is not None else {}), **kwargs, } response = await self.acompletion_with_retry( messages=message_dicts, run_manager=run_manager, **params ) return self._create_chat_result(response) def _stream( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: """Stream the chat response in chunks.""" message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} default_chunk_class = AIMessageChunk for chunk in self.completion_with_retry( messages=message_dicts, run_manager=run_manager, **params ): if not isinstance(chunk, dict): chunk = chunk.dict() if len(chunk["choices"]) == 0: continue choice = chunk["choices"][0] chunk = _convert_delta_to_message_chunk( choice["delta"], default_chunk_class ) finish_reason = choice.get("finish_reason") generation_info = ( dict(finish_reason=finish_reason) if finish_reason is not None else None ) default_chunk_class = chunk.__class__ chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info) yield chunk if run_manager: run_manager.on_llm_new_token(chunk.text, chunk=chunk) """问答系统界面""" # 自定义CSS custom_css = """ #gr-chatbot { background-color: #FFFFFF !important; /* 设置Chatbot组件背景为浅灰色 */ color: #000000 !important; /* 设置Chatbot字体颜色为黑色 */ border: 2px solid #ccc; /* 设置Chatbot边框 */ border-radius: 20px; /* 设置圆角边框 */ padding: 10px; height: 1600px; display: block !important; /* 强制为块级元素 */ } #example-btn { width: 250px !important; /* 按钮宽度变窄 */ font-size: 20px !important; /* 按钮字体大小 */ font-family: "仿宋", serif !important; /* 设置字体为宋体 */ padding: 5px !important; /* 调整按钮内边距 */ margin-bottom: 5px; /* 按钮之间的间距 */ background-color: #E4E4E7 !important; border: none !important; /* 去掉按钮的边框 */ } #example-container { border: 2px solid #ccc; /* 设置方框边框 */ padding:10px; /* 增加内边距 */ border-radius: 30px; /* 设置圆角 */ margin-bottom: 10px; /* 与其他元素保持间距 */ width: 285px !important; /* 固定按钮容器宽度 */ height: 460px !important; /* 固定按钮容器高度 */ overflow-y: auto !important; /* 自动显示垂直滚动条 */ overflow-x: hidden !重要; /* 禁用水平滚动条 */ display: block !important; /* 强制为块级元素 */ background-color: #E4E4E7 !important; } #send-btn { width: 60px !important; /* 缩小按钮宽度 */ height: 60px !important; /* 缩小按钮的高度 */ font-size: 20px; /* 调整按钮字体大小 */ font-family: "仿宋", serif !important; /* 设置字体为宋体 */ padding: 5px; /* 设置按钮的内边距 */ vertical-align: middle; /* 按钮内容垂直居中 */ background-color: #FFFFFF !important; /* 按钮背景为白色 */ border: 1px solid #CCCCCC !important; /* 将按钮的边框设置为黑色 */ } #textbox { height: 60px !重要; /* 设置输入框的高度 */ line-height: normal !重要; /* 确保行高正常 */ font-size: 20px !重要; /* 调整输入框的字体大小 */ } """ examples = [ "请问感冒的症状有哪些?", "如何预防高血压?", "治疗糖尿病的常见方法是什么?", "什么是流感?", "什么是偏头痛的常见诱因?", "支气管炎怎么治疗?", "冠心病的早期症状是什么?", "治疗哮喘有哪些常见药物?", "癫痫病如何治疗?", "高血压不能吃什么?" ] # HTML内容:标题的定义 html_content = """