|
|
|
@ -0,0 +1,966 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
import os
|
|
|
|
|
from PIL import Image
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
from torchvision import models, transforms
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from threading import Thread
|
|
|
|
|
from typing import Union
|
|
|
|
|
import gradio as gr
|
|
|
|
|
import torch
|
|
|
|
|
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
|
|
|
|
|
from transformers import (
|
|
|
|
|
AutoModelForCausalLM,
|
|
|
|
|
AutoTokenizer,
|
|
|
|
|
PreTrainedModel,
|
|
|
|
|
PreTrainedTokenizer,
|
|
|
|
|
PreTrainedTokenizerFast,
|
|
|
|
|
StoppingCriteria,
|
|
|
|
|
StoppingCriteriaList,
|
|
|
|
|
TextIteratorStreamer, AutoModel
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
|
|
|
from langchain_community.document_loaders import WebBaseLoader
|
|
|
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings
|
|
|
|
|
from langchain_community.vectorstores import FAISS
|
|
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
|
|
from langchain.chains.combine_documents import create_stuff_documents_chain
|
|
|
|
|
from langchain_community.document_loaders import TextLoader
|
|
|
|
|
from langchain.chains import create_retrieval_chain
|
|
|
|
|
from langchain.tools.retriever import create_retriever_tool
|
|
|
|
|
import asyncio
|
|
|
|
|
import numpy as np
|
|
|
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
|
import logging
|
|
|
|
|
import pdfplumber
|
|
|
|
|
from functools import partial
|
|
|
|
|
from importlib.metadata import version
|
|
|
|
|
from typing import (
|
|
|
|
|
Any,
|
|
|
|
|
Callable,
|
|
|
|
|
Dict,
|
|
|
|
|
Iterator,
|
|
|
|
|
List,
|
|
|
|
|
Mapping,
|
|
|
|
|
Optional,
|
|
|
|
|
Tuple,
|
|
|
|
|
Type,
|
|
|
|
|
Union,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from langchain_core.callbacks import (
|
|
|
|
|
AsyncCallbackManagerForLLMRun,
|
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.language_models.chat_models import (
|
|
|
|
|
BaseChatModel,
|
|
|
|
|
generate_from_stream,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.language_models.llms import create_base_retry_decorator
|
|
|
|
|
from langchain_core.messages import (
|
|
|
|
|
AIMessage,
|
|
|
|
|
AIMessageChunk,
|
|
|
|
|
BaseMessage,
|
|
|
|
|
BaseMessageChunk,
|
|
|
|
|
ChatMessage,
|
|
|
|
|
ChatMessageChunk,
|
|
|
|
|
HumanMessage,
|
|
|
|
|
HumanMessageChunk,
|
|
|
|
|
SystemMessage,
|
|
|
|
|
SystemMessageChunk,
|
|
|
|
|
ToolMessage,
|
|
|
|
|
ToolMessageChunk,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.outputs import (
|
|
|
|
|
ChatGeneration,
|
|
|
|
|
ChatGenerationChunk,
|
|
|
|
|
ChatResult,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
|
|
|
from packaging.version import parse
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
huggingface_path = "/home/wangty/lsx/temp-lsx/word_language_model/models--sentence-transformers--all-mpnet-base-v2"
|
|
|
|
|
"""ZhipuAI与langchain部分不兼容,重写langchain部分代码"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_zhipu_v2() -> bool:
|
|
|
|
|
"""Return whether zhipu API is v2 or more."""
|
|
|
|
|
_version = parse(version("zhipuai"))
|
|
|
|
|
return _version.major >= 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_retry_decorator(
|
|
|
|
|
llm: ChatZhipuAI,
|
|
|
|
|
run_manager: Optional[
|
|
|
|
|
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
|
|
|
|
] = None,
|
|
|
|
|
) -> Callable[[Any], Any]:
|
|
|
|
|
import zhipuai
|
|
|
|
|
|
|
|
|
|
errors = [
|
|
|
|
|
zhipuai.ZhipuAIError,
|
|
|
|
|
zhipuai.APIStatusError,
|
|
|
|
|
zhipuai.APIRequestFailedError,
|
|
|
|
|
zhipuai.APIReachLimitError,
|
|
|
|
|
zhipuai.APIInternalError,
|
|
|
|
|
zhipuai.APIServerFlowExceedError,
|
|
|
|
|
zhipuai.APIResponseError,
|
|
|
|
|
zhipuai.APIResponseValidationError,
|
|
|
|
|
zhipuai.APITimeoutError,
|
|
|
|
|
]
|
|
|
|
|
return create_base_retry_decorator(
|
|
|
|
|
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_message_to_dict(message: BaseMessage) -> dict:
|
|
|
|
|
"""Convert a LangChain message to a dictionary.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
message: The LangChain message.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
The dictionary.
|
|
|
|
|
"""
|
|
|
|
|
message_dict: Dict[str, Any]
|
|
|
|
|
if isinstance(message, ChatMessage):
|
|
|
|
|
message_dict = {"role": message.role, "content": message.content}
|
|
|
|
|
elif isinstance(message, HumanMessage):
|
|
|
|
|
message_dict = {"role": "user", "content": message.content}
|
|
|
|
|
elif isinstance(message, AIMessage):
|
|
|
|
|
message_dict = {"role": "assistant", "content": message.content}
|
|
|
|
|
if "tool_calls" in message.additional_kwargs:
|
|
|
|
|
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
|
|
|
|
# If tool calls only, content is None not empty string
|
|
|
|
|
if message_dict["content"] == "":
|
|
|
|
|
message_dict["content"] = None
|
|
|
|
|
elif isinstance(message, SystemMessage):
|
|
|
|
|
message_dict = {"role": "system", "content": message.content}
|
|
|
|
|
elif isinstance(message, ToolMessage):
|
|
|
|
|
message_dict = {
|
|
|
|
|
"role": "tool",
|
|
|
|
|
"content": message.content,
|
|
|
|
|
"tool_call_id": message.tool_call_id,
|
|
|
|
|
}
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(f"Got unknown type {message}")
|
|
|
|
|
if "name" in message.additional_kwargs:
|
|
|
|
|
message_dict["name"] = message.additional_kwargs["name"]
|
|
|
|
|
return message_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|
|
|
|
"""Convert a dictionary to a LangChain message.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
_dict: The dictionary.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
The LangChain message.
|
|
|
|
|
"""
|
|
|
|
|
role = _dict.get("role")
|
|
|
|
|
if role == "user":
|
|
|
|
|
return HumanMessage(content=_dict.get("content", ""))
|
|
|
|
|
elif role == "assistant":
|
|
|
|
|
content = _dict.get("content", "") or ""
|
|
|
|
|
additional_kwargs: Dict = {}
|
|
|
|
|
if tool_calls := _dict.get("tool_calls"):
|
|
|
|
|
additional_kwargs["tool_calls"] = tool_calls
|
|
|
|
|
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
|
|
|
|
elif role == "system":
|
|
|
|
|
return SystemMessage(content=_dict.get("content", ""))
|
|
|
|
|
elif role == "tool":
|
|
|
|
|
additional_kwargs = {}
|
|
|
|
|
if "name" in _dict:
|
|
|
|
|
additional_kwargs["name"] = _dict["name"]
|
|
|
|
|
return ToolMessage(
|
|
|
|
|
content=_dict.get("content", ""),
|
|
|
|
|
tool_call_id=_dict.get("tool_call_id"),
|
|
|
|
|
additional_kwargs=additional_kwargs,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
return ChatMessage(content=_dict.get("content", ""), role=role)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_delta_to_message_chunk(
|
|
|
|
|
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
|
|
|
|
) -> BaseMessageChunk:
|
|
|
|
|
role = _dict.get("role")
|
|
|
|
|
content = _dict.get("content") or ""
|
|
|
|
|
additional_kwargs: Dict = {}
|
|
|
|
|
if _dict.get("tool_calls"):
|
|
|
|
|
additional_kwargs["tool_calls"] = _dict["tool_calls"]
|
|
|
|
|
|
|
|
|
|
if role == "user" or default_class == HumanMessageChunk:
|
|
|
|
|
return HumanMessageChunk(content=content)
|
|
|
|
|
elif role == "assistant" or default_class == AIMessageChunk:
|
|
|
|
|
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
|
|
|
|
elif role == "system" or default_class == SystemMessageChunk:
|
|
|
|
|
return SystemMessageChunk(content=content)
|
|
|
|
|
elif role == "tool" or default_class == ToolMessageChunk:
|
|
|
|
|
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
|
|
|
|
|
elif role or default_class == ChatMessageChunk:
|
|
|
|
|
return ChatMessageChunk(content=content, role=role)
|
|
|
|
|
else:
|
|
|
|
|
return default_class(content=content)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatZhipuAI(BaseChatModel):
|
|
|
|
|
"""
|
|
|
|
|
`ZHIPU AI` large language chat models API.
|
|
|
|
|
|
|
|
|
|
To use, you should have the ``zhipuai`` python package installed.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
from langchain_community.chat_models import ChatZhipuAI
|
|
|
|
|
|
|
|
|
|
zhipuai_chat = ChatZhipuAI(
|
|
|
|
|
temperature=0.5,
|
|
|
|
|
api_key="your-api-key",
|
|
|
|
|
model_name="glm-3-turbo",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
zhipuai: Any
|
|
|
|
|
zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key")
|
|
|
|
|
"""Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided."""
|
|
|
|
|
|
|
|
|
|
client: Any = Field(default=None, exclude=True) #: :meta private:
|
|
|
|
|
|
|
|
|
|
model_name: str = Field("glm-3-turbo", alias="model")
|
|
|
|
|
"""
|
|
|
|
|
Model name to use.
|
|
|
|
|
-glm-3-turbo:
|
|
|
|
|
According to the input of natural language instructions to complete a
|
|
|
|
|
variety of language tasks, it is recommended to use SSE or asynchronous
|
|
|
|
|
call request interface.
|
|
|
|
|
-glm-4:
|
|
|
|
|
According to the input of natural language instructions to complete a
|
|
|
|
|
variety of language tasks, it is recommended to use SSE or asynchronous
|
|
|
|
|
call request interface.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
temperature: float = Field(0.95)
|
|
|
|
|
"""
|
|
|
|
|
What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot
|
|
|
|
|
be equal to 0.
|
|
|
|
|
The larger the value, the more random and creative the output; The smaller
|
|
|
|
|
the value, the more stable or certain the output will be.
|
|
|
|
|
You are advised to adjust top_p or temperature parameters based on application
|
|
|
|
|
scenarios, but do not adjust the two parameters at the same time.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
top_p: float = Field(0.7)
|
|
|
|
|
"""
|
|
|
|
|
Another method of sampling temperature is called nuclear sampling. The value
|
|
|
|
|
ranges from 0.0 to 1.0 and cannot be equal to 0 or 1.
|
|
|
|
|
The model considers the results with top_p probability quality tokens.
|
|
|
|
|
For example, 0.1 means that the model decoder only considers tokens from the
|
|
|
|
|
top 10% probability of the candidate set.
|
|
|
|
|
You are advised to adjust top_p or temperature parameters based on application
|
|
|
|
|
scenarios, but do not adjust the two parameters at the same time.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
request_id: Optional[str] = Field(None)
|
|
|
|
|
"""
|
|
|
|
|
Parameter transmission by the client must ensure uniqueness; A unique
|
|
|
|
|
identifier used to distinguish each request, which is generated by default
|
|
|
|
|
by the platform when the client does not transmit it.
|
|
|
|
|
"""
|
|
|
|
|
do_sample: Optional[bool] = Field(True)
|
|
|
|
|
"""
|
|
|
|
|
When do_sample is true, the sampling policy is enabled. When do_sample is false,
|
|
|
|
|
the sampling policy temperature and top_p are disabled
|
|
|
|
|
"""
|
|
|
|
|
streaming: bool = Field(False)
|
|
|
|
|
"""Whether to stream the results or not."""
|
|
|
|
|
|
|
|
|
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
|
|
|
|
|
|
|
|
|
max_tokens: Optional[int] = None
|
|
|
|
|
"""Number of chat completions to generate for each prompt."""
|
|
|
|
|
|
|
|
|
|
max_retries: int = 2
|
|
|
|
|
"""Maximum number of retries to make when generating."""
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _identifying_params(self) -> Dict[str, Any]:
|
|
|
|
|
"""Get the identifying parameters."""
|
|
|
|
|
return {**{"model_name": self.model_name}, **self._default_params}
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _llm_type(self) -> str:
|
|
|
|
|
"""Return the type of chat model."""
|
|
|
|
|
return "zhipuai"
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def lc_secrets(self) -> Dict[str, str]:
|
|
|
|
|
return {"zhipuai_api_key": "ZHIPUAI_API_KEY"}
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_lc_namespace(cls) -> List[str]:
|
|
|
|
|
"""Get the namespace of the langchain object."""
|
|
|
|
|
return ["langchain", "chat_models", "zhipuai"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def lc_attributes(self) -> Dict[str, Any]:
|
|
|
|
|
attributes: Dict[str, Any] = {}
|
|
|
|
|
|
|
|
|
|
if self.model_name:
|
|
|
|
|
attributes["model"] = self.model_name
|
|
|
|
|
|
|
|
|
|
if self.streaming:
|
|
|
|
|
attributes["streaming"] = self.streaming
|
|
|
|
|
|
|
|
|
|
if self.max_tokens:
|
|
|
|
|
attributes["max_tokens"] = self.max_tokens
|
|
|
|
|
|
|
|
|
|
return attributes
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _default_params(self) -> Dict[str, Any]:
|
|
|
|
|
"""Get the default parameters for calling ZhipuAI API."""
|
|
|
|
|
params = {
|
|
|
|
|
"model": self.model_name,
|
|
|
|
|
"stream": self.streaming,
|
|
|
|
|
"temperature": self.temperature,
|
|
|
|
|
"top_p": self.top_p,
|
|
|
|
|
"do_sample": self.do_sample,
|
|
|
|
|
**self.model_kwargs,
|
|
|
|
|
}
|
|
|
|
|
if self.max_tokens is not None:
|
|
|
|
|
params["max_tokens"] = self.max_tokens
|
|
|
|
|
return params
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _client_params(self) -> Dict[str, Any]:
|
|
|
|
|
"""Get the parameters used for the zhipuai client."""
|
|
|
|
|
zhipuai_creds: Dict[str, Any] = {
|
|
|
|
|
"request_id": self.request_id,
|
|
|
|
|
}
|
|
|
|
|
return {**self._default_params, **zhipuai_creds}
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
try:
|
|
|
|
|
from zhipuai import ZhipuAI
|
|
|
|
|
|
|
|
|
|
if not is_zhipu_v2():
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"zhipuai package version is too low"
|
|
|
|
|
"Please install it via 'pip install --upgrade zhipuai'"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.client = ZhipuAI(
|
|
|
|
|
api_key=self.zhipuai_api_key, # 填写您的 APIKey
|
|
|
|
|
)
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"Could not import zhipuai package. "
|
|
|
|
|
"Please install it via 'pip install zhipuai'"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def completions(self, **kwargs) -> Any | None:
|
|
|
|
|
return self.client.chat.completions.create(**kwargs)
|
|
|
|
|
|
|
|
|
|
async def async_completions(self, **kwargs) -> Any:
|
|
|
|
|
loop = asyncio.get_running_loop()
|
|
|
|
|
partial_func = partial(self.client.chat.completions.create, **kwargs)
|
|
|
|
|
response = await loop.run_in_executor(
|
|
|
|
|
None,
|
|
|
|
|
partial_func,
|
|
|
|
|
)
|
|
|
|
|
return response
|
|
|
|
|
|
|
|
|
|
async def async_completions_result(self, task_id):
|
|
|
|
|
loop = asyncio.get_running_loop()
|
|
|
|
|
response = await loop.run_in_executor(
|
|
|
|
|
None,
|
|
|
|
|
self.client.asyncCompletions.retrieve_completion_result,
|
|
|
|
|
task_id,
|
|
|
|
|
)
|
|
|
|
|
return response
|
|
|
|
|
|
|
|
|
|
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
|
|
|
|
|
generations = []
|
|
|
|
|
if not isinstance(response, dict):
|
|
|
|
|
response = response.dict()
|
|
|
|
|
for res in response["choices"]:
|
|
|
|
|
message = convert_dict_to_message(res["message"])
|
|
|
|
|
generation_info = dict(finish_reason=res.get("finish_reason"))
|
|
|
|
|
if "index" in res:
|
|
|
|
|
generation_info["index"] = res["index"]
|
|
|
|
|
gen = ChatGeneration(
|
|
|
|
|
message=message,
|
|
|
|
|
generation_info=generation_info,
|
|
|
|
|
)
|
|
|
|
|
generations.append(gen)
|
|
|
|
|
token_usage = response.get("usage", {})
|
|
|
|
|
llm_output = {
|
|
|
|
|
"token_usage": token_usage,
|
|
|
|
|
"model_name": self.model_name,
|
|
|
|
|
"task_id": response.get("id", ""),
|
|
|
|
|
"created_time": response.get("created", ""),
|
|
|
|
|
}
|
|
|
|
|
return ChatResult(generations=generations, llm_output=llm_output)
|
|
|
|
|
|
|
|
|
|
def _create_message_dicts(
|
|
|
|
|
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
|
|
|
|
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
|
|
|
|
params = self._client_params
|
|
|
|
|
if stop is not None:
|
|
|
|
|
if "stop" in params:
|
|
|
|
|
raise ValueError("`stop` found in both the input and default params.")
|
|
|
|
|
params["stop"] = stop
|
|
|
|
|
message_dicts = [convert_message_to_dict(m) for m in messages]
|
|
|
|
|
return message_dicts, params
|
|
|
|
|
|
|
|
|
|
def completion_with_retry(
|
|
|
|
|
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
|
|
|
|
|
) -> Any:
|
|
|
|
|
"""Use tenacity to retry the completion call."""
|
|
|
|
|
|
|
|
|
|
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
|
|
|
|
|
|
|
|
|
|
@retry_decorator
|
|
|
|
|
def _completion_with_retry(**kwargs: Any) -> Any:
|
|
|
|
|
return self.completions(**kwargs)
|
|
|
|
|
|
|
|
|
|
return _completion_with_retry(**kwargs)
|
|
|
|
|
|
|
|
|
|
async def acompletion_with_retry(
|
|
|
|
|
self,
|
|
|
|
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> Any:
|
|
|
|
|
"""Use tenacity to retry the async completion call."""
|
|
|
|
|
|
|
|
|
|
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
|
|
|
|
|
|
|
|
|
|
@retry_decorator
|
|
|
|
|
async def _completion_with_retry(**kwargs: Any) -> Any:
|
|
|
|
|
return await self.async_completions(**kwargs)
|
|
|
|
|
|
|
|
|
|
return await _completion_with_retry(**kwargs)
|
|
|
|
|
|
|
|
|
|
def _generate(
|
|
|
|
|
self,
|
|
|
|
|
messages: List[BaseMessage],
|
|
|
|
|
stop: Optional[List[str]] = None,
|
|
|
|
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
|
|
|
stream: Optional[bool] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> ChatResult:
|
|
|
|
|
"""Generate a chat response."""
|
|
|
|
|
|
|
|
|
|
should_stream = stream if stream is not None else self.streaming
|
|
|
|
|
if should_stream:
|
|
|
|
|
stream_iter = self._stream(
|
|
|
|
|
messages, stop=stop, run_manager=run_manager, **kwargs
|
|
|
|
|
)
|
|
|
|
|
return generate_from_stream(stream_iter)
|
|
|
|
|
|
|
|
|
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
|
|
|
|
params = {
|
|
|
|
|
**params,
|
|
|
|
|
**({"stream": stream} if stream is not None else {}),
|
|
|
|
|
**kwargs,
|
|
|
|
|
}
|
|
|
|
|
response = self.completion_with_retry(
|
|
|
|
|
messages=message_dicts, run_manager=run_manager, **params
|
|
|
|
|
)
|
|
|
|
|
return self._create_chat_result(response)
|
|
|
|
|
|
|
|
|
|
async def _agenerate(
|
|
|
|
|
self,
|
|
|
|
|
messages: List[BaseMessage],
|
|
|
|
|
stop: Optional[List[str]] = None,
|
|
|
|
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
|
|
|
stream: Optional[bool] = False,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> ChatResult:
|
|
|
|
|
"""Asynchronously generate a chat response."""
|
|
|
|
|
should_stream = stream if stream is not None else self.streaming
|
|
|
|
|
if should_stream:
|
|
|
|
|
stream_iter = self._astream(
|
|
|
|
|
messages, stop=stop, run_manager=run_manager, **kwargs
|
|
|
|
|
)
|
|
|
|
|
return generate_from_stream(stream_iter)
|
|
|
|
|
|
|
|
|
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
|
|
|
|
params = {
|
|
|
|
|
**params,
|
|
|
|
|
**({"stream": stream} if stream is not None else {}),
|
|
|
|
|
**kwargs,
|
|
|
|
|
}
|
|
|
|
|
response = await self.acompletion_with_retry(
|
|
|
|
|
messages=message_dicts, run_manager=run_manager, **params
|
|
|
|
|
)
|
|
|
|
|
return self._create_chat_result(response)
|
|
|
|
|
|
|
|
|
|
def _stream(
|
|
|
|
|
self,
|
|
|
|
|
messages: List[BaseMessage],
|
|
|
|
|
stop: Optional[List[str]] = None,
|
|
|
|
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> Iterator[ChatGenerationChunk]:
|
|
|
|
|
"""Stream the chat response in chunks."""
|
|
|
|
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
|
|
|
|
params = {**params, **kwargs, "stream": True}
|
|
|
|
|
|
|
|
|
|
default_chunk_class = AIMessageChunk
|
|
|
|
|
for chunk in self.completion_with_retry(
|
|
|
|
|
messages=message_dicts, run_manager=run_manager, **params
|
|
|
|
|
):
|
|
|
|
|
if not isinstance(chunk, dict):
|
|
|
|
|
chunk = chunk.dict()
|
|
|
|
|
if len(chunk["choices"]) == 0:
|
|
|
|
|
continue
|
|
|
|
|
choice = chunk["choices"][0]
|
|
|
|
|
chunk = _convert_delta_to_message_chunk(
|
|
|
|
|
choice["delta"], default_chunk_class
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
finish_reason = choice.get("finish_reason")
|
|
|
|
|
generation_info = (
|
|
|
|
|
dict(finish_reason=finish_reason) if finish_reason is not None else None
|
|
|
|
|
)
|
|
|
|
|
default_chunk_class = chunk.__class__
|
|
|
|
|
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
|
|
|
|
|
yield chunk
|
|
|
|
|
if run_manager:
|
|
|
|
|
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""问答系统界面"""
|
|
|
|
|
# 自定义CSS
|
|
|
|
|
custom_css = """
|
|
|
|
|
#gr-chatbot {
|
|
|
|
|
background-color: #FFFFFF !important; /* 设置Chatbot组件背景为浅灰色 */
|
|
|
|
|
color: #000000 !important; /* 设置Chatbot字体颜色为黑色 */
|
|
|
|
|
border: 2px solid #ccc; /* 设置Chatbot边框 */
|
|
|
|
|
border-radius: 20px; /* 设置圆角边框 */
|
|
|
|
|
padding: 10px;
|
|
|
|
|
height: 1600px;
|
|
|
|
|
display: block !important; /* 强制为块级元素 */
|
|
|
|
|
}
|
|
|
|
|
#example-btn {
|
|
|
|
|
|
|
|
|
|
width: 250px !important; /* 按钮宽度变窄 */
|
|
|
|
|
font-size: 20px !important; /* 按钮字体大小 */
|
|
|
|
|
font-family: "仿宋", serif !important; /* 设置字体为宋体 */
|
|
|
|
|
padding: 5px !important; /* 调整按钮内边距 */
|
|
|
|
|
margin-bottom: 5px; /* 按钮之间的间距 */
|
|
|
|
|
background-color: #E4E4E7 !important;
|
|
|
|
|
border: none !important; /* 去掉按钮的边框 */
|
|
|
|
|
}
|
|
|
|
|
#example-container {
|
|
|
|
|
border: 2px solid #ccc; /* 设置方框边框 */
|
|
|
|
|
padding:10px; /* 增加内边距 */
|
|
|
|
|
border-radius: 30px; /* 设置圆角 */
|
|
|
|
|
margin-bottom: 10px; /* 与其他元素保持间距 */
|
|
|
|
|
width: 285px !important; /* 固定按钮容器宽度 */
|
|
|
|
|
height: 460px !important; /* 固定按钮容器高度 */
|
|
|
|
|
overflow-y: auto !important; /* 自动显示垂直滚动条 */
|
|
|
|
|
overflow-x: hidden !重要; /* 禁用水平滚动条 */
|
|
|
|
|
display: block !important; /* 强制为块级元素 */
|
|
|
|
|
background-color: #E4E4E7 !important;
|
|
|
|
|
}
|
|
|
|
|
#send-btn {
|
|
|
|
|
width: 60px !important; /* 缩小按钮宽度 */
|
|
|
|
|
height: 60px !important; /* 缩小按钮的高度 */
|
|
|
|
|
font-size: 20px; /* 调整按钮字体大小 */
|
|
|
|
|
font-family: "仿宋", serif !important; /* 设置字体为宋体 */
|
|
|
|
|
padding: 5px; /* 设置按钮的内边距 */
|
|
|
|
|
vertical-align: middle; /* 按钮内容垂直居中 */
|
|
|
|
|
background-color: #FFFFFF !important; /* 按钮背景为白色 */
|
|
|
|
|
border: 1px solid #CCCCCC !important; /* 将按钮的边框设置为黑色 */
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#textbox {
|
|
|
|
|
height: 60px !重要; /* 设置输入框的高度 */
|
|
|
|
|
line-height: normal !重要; /* 确保行高正常 */
|
|
|
|
|
font-size: 20px !重要; /* 调整输入框的字体大小 */
|
|
|
|
|
}
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
examples = [
|
|
|
|
|
"请问感冒的症状有哪些?",
|
|
|
|
|
"如何预防高血压?",
|
|
|
|
|
"治疗糖尿病的常见方法是什么?",
|
|
|
|
|
"什么是流感?",
|
|
|
|
|
"什么是偏头痛的常见诱因?",
|
|
|
|
|
"支气管炎怎么治疗?",
|
|
|
|
|
"冠心病的早期症状是什么?",
|
|
|
|
|
"治疗哮喘有哪些常见药物?",
|
|
|
|
|
"癫痫病如何治疗?",
|
|
|
|
|
"高血压不能吃什么?"
|
|
|
|
|
]
|
|
|
|
|
# HTML内容:标题的定义
|
|
|
|
|
html_content = """
|
|
|
|
|
<h1 style="text-align: center;">医疗聊天机器人</h1>
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
|
|
|
|
|
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
|
|
|
|
|
|
|
|
|
MODEL_PATH = os.environ.get('MODEL_PATH', '/home/wangty/lsx/temp-lsx/word_language_model/medical_demo')
|
|
|
|
|
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_path(path: Union[str, Path]) -> Path:
|
|
|
|
|
return Path(path).expanduser().resolve()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model_and_tokenizer(
|
|
|
|
|
model_dir: Union[str, Path], trust_remote_code: bool = True
|
|
|
|
|
) -> tuple[ModelType, TokenizerType]:
|
|
|
|
|
model_dir = _resolve_path(model_dir)
|
|
|
|
|
if (model_dir / 'adapter_config.json').exists():
|
|
|
|
|
model = AutoPeftModelForCausalLM.from_pretrained(
|
|
|
|
|
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
|
|
|
|
|
)
|
|
|
|
|
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
|
|
|
|
|
else:
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
|
|
|
|
|
)
|
|
|
|
|
tokenizer_dir = model_dir
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
|
tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
|
|
|
|
|
)
|
|
|
|
|
return model, tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)
|
|
|
|
|
|
|
|
|
|
"""重写StopTokens,部分停止id不一致"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StopOnTokens(StoppingCriteria):
|
|
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
|
|
|
|
stop_ids = model.config.eos_token_id
|
|
|
|
|
for stop_id in stop_ids:
|
|
|
|
|
if input_ids[0][-1] == stop_id:
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ZHIPUAI_API_KEY = "4a1f42142b2adfd2bdcbb864eb8ddf8f.6o0LAkhsANjQHGlJ"
|
|
|
|
|
llm = ChatZhipuAI(
|
|
|
|
|
temperature=0.1,
|
|
|
|
|
api_key=ZHIPUAI_API_KEY,
|
|
|
|
|
model_name="glm-4",
|
|
|
|
|
)
|
|
|
|
|
prompt = ChatPromptTemplate.from_template("""仅根据所提供的上下文回答以下问题,并且禁止添加自己认为的信息,如果上下文没有,输出暂无结果:
|
|
|
|
|
<context>
|
|
|
|
|
{context}
|
|
|
|
|
</context>
|
|
|
|
|
|
|
|
|
|
问题: {input}
|
|
|
|
|
""")
|
|
|
|
|
|
|
|
|
|
output_parser = StrOutputParser()
|
|
|
|
|
embeddings = HuggingFaceEmbeddings(model_name=huggingface_path)
|
|
|
|
|
text_splitter = RecursiveCharacterTextSplitter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_pdf(file_path):
|
|
|
|
|
text = ""
|
|
|
|
|
with pdfplumber.open(file_path) as pdf:
|
|
|
|
|
for page in pdf.pages:
|
|
|
|
|
text += page.extract_text() or ""
|
|
|
|
|
return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def original(data: str) -> str:
|
|
|
|
|
prompt_none = ChatPromptTemplate.from_messages([
|
|
|
|
|
("system", "您是一个助手,请回答用户的问题。"),
|
|
|
|
|
("user", "{input}")
|
|
|
|
|
])
|
|
|
|
|
output_parser_original = StrOutputParser()
|
|
|
|
|
chain = prompt_none | llm | output_parser_original
|
|
|
|
|
response = chain.invoke({"input": data}) # 调用链条,并传入数据
|
|
|
|
|
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_uploaded_file(uploaded_file, question):
|
|
|
|
|
# 生成原始模型输出
|
|
|
|
|
original_answer = original(question)
|
|
|
|
|
|
|
|
|
|
file_path = uploaded_file.name
|
|
|
|
|
if file_path.endswith(".pdf"):
|
|
|
|
|
text = read_pdf(file_path)
|
|
|
|
|
else:
|
|
|
|
|
loader = TextLoader(file_path)
|
|
|
|
|
text = loader.load()[0].page_content
|
|
|
|
|
# 文本分割和嵌入处理
|
|
|
|
|
documents = text_splitter.split_text(text)
|
|
|
|
|
vector = FAISS.from_texts(documents, embeddings)
|
|
|
|
|
retriever = vector.as_retriever() # 返回相关片段
|
|
|
|
|
document_chain = create_stuff_documents_chain(llm, prompt)
|
|
|
|
|
retrieval_chain = create_retrieval_chain(retriever, document_chain)
|
|
|
|
|
response = retrieval_chain.invoke({"input": question})
|
|
|
|
|
# 获取问题的嵌入向量
|
|
|
|
|
question_embedding = embeddings.embed_query(question)
|
|
|
|
|
# 获取所有文档的嵌入向量
|
|
|
|
|
document_embeddings = embeddings.embed_documents(documents)
|
|
|
|
|
# 计算每个文档与问题之间的余弦相似度
|
|
|
|
|
scores = cosine_similarity([question_embedding], document_embeddings)[0]
|
|
|
|
|
|
|
|
|
|
# 打印所有文档的分数以进行调试
|
|
|
|
|
print("所有文档的相似度得分:")
|
|
|
|
|
for i, (doc, score) in enumerate(zip(documents, scores)):
|
|
|
|
|
print(f"段落 {i + 1}:")
|
|
|
|
|
print(f"内容: {doc[:100]}...") # 打印前100个字符,避免输出过多
|
|
|
|
|
print(f"得分: {score}")
|
|
|
|
|
print("-" * 40) # 分隔线,方便阅读
|
|
|
|
|
|
|
|
|
|
# 筛选出得分大于等于 0.1 的文档(暂时降低阈值以确保有输出)
|
|
|
|
|
high_score_docs = [
|
|
|
|
|
(doc, score) for doc, score in zip(documents, scores) if score >= 0.3
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# 如果没有高分文档,返回默认信息
|
|
|
|
|
if not high_score_docs:
|
|
|
|
|
return original_answer, "未找到相关信息", "未找到相关上下文"
|
|
|
|
|
|
|
|
|
|
# 提取高分文档上下文内容
|
|
|
|
|
context = "\n\n".join([doc for doc, _ in high_score_docs[:3]])
|
|
|
|
|
|
|
|
|
|
# 使用上下文调用模型
|
|
|
|
|
|
|
|
|
|
# 检查回答内容与置信度
|
|
|
|
|
answer = response.get("answer", "").strip()
|
|
|
|
|
confidence = response.get("confidence", 1.0)
|
|
|
|
|
|
|
|
|
|
# 如果无答案或置信度低于阈值,返回默认信息
|
|
|
|
|
if not answer or confidence < 0.5:
|
|
|
|
|
answer = "未找到相关信息"
|
|
|
|
|
|
|
|
|
|
return original_answer, answer, context
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(history):
|
|
|
|
|
stop = StopOnTokens()
|
|
|
|
|
messages = []
|
|
|
|
|
|
|
|
|
|
# 在模型生成响应之前输出历史记录
|
|
|
|
|
for idx, (user_msg, model_msg) in enumerate(history):
|
|
|
|
|
# 在处理 history 时,确保每个条目是 [user_message, model_message] 的格式
|
|
|
|
|
if idx == len(history) - 1 and not model_msg:
|
|
|
|
|
messages.append({"role": "user", "content": user_msg})
|
|
|
|
|
break
|
|
|
|
|
if user_msg:
|
|
|
|
|
messages.append({"role": "user", "content": user_msg})
|
|
|
|
|
if model_msg:
|
|
|
|
|
messages.append({"role": "assistant", "content": model_msg})
|
|
|
|
|
|
|
|
|
|
# 如果 messages 为空,直接返回
|
|
|
|
|
if not messages:
|
|
|
|
|
return history
|
|
|
|
|
|
|
|
|
|
model_inputs = tokenizer.apply_chat_template(messages,
|
|
|
|
|
add_generation_prompt=True,
|
|
|
|
|
tokenize=True,
|
|
|
|
|
return_tensors="pt").to(next(model.parameters()).device)
|
|
|
|
|
streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
|
|
|
|
|
generate_kwargs = {
|
|
|
|
|
"input_ids": model_inputs,
|
|
|
|
|
"streamer": streamer,
|
|
|
|
|
"max_new_tokens": 8192, # 固定的 max_length 值
|
|
|
|
|
"do_sample": True,
|
|
|
|
|
"top_p": 0.7, # 固定的 top_p 值
|
|
|
|
|
"temperature": 0.9, # 固定的 temperature 值
|
|
|
|
|
"stopping_criteria": StoppingCriteriaList([stop]),
|
|
|
|
|
"repetition_penalty": 1.2,
|
|
|
|
|
"eos_token_id": model.config.eos_token_id,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
|
|
|
|
t.start()
|
|
|
|
|
|
|
|
|
|
for new_token in streamer:
|
|
|
|
|
# 确保 history 的最后一项是 [user_message, model_message] 的格式
|
|
|
|
|
if len(history[-1]) < 2:
|
|
|
|
|
history[-1].append(new_token) # 添加模型的响应
|
|
|
|
|
else:
|
|
|
|
|
history[-1][1] += new_token # 如果已经有模型响应,继续追加
|
|
|
|
|
|
|
|
|
|
# 输出生成新 token 后的历史记录
|
|
|
|
|
yield history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 图片分类使用Resnet进行图片分类
|
|
|
|
|
# path:/home/wangty/lsx/temp-lsx/word_language_model/recog_image/best_resnet18_model.pth
|
|
|
|
|
model_resnet = models.resnet18(weights=None)
|
|
|
|
|
model_resnet = nn.DataParallel(model_resnet)
|
|
|
|
|
num_classes = 7
|
|
|
|
|
model_resnet.module.fc = nn.Linear(model_resnet.module.fc.in_features, num_classes)
|
|
|
|
|
model_resnet.load_state_dict(torch.load('/home/wangty/lsx/temp-lsx/recog_image/code/best_resnet18_model.pth', map_location=torch.device('cuda')))
|
|
|
|
|
model_resnet = model_resnet.to('cuda:0')
|
|
|
|
|
model_resnet.eval()
|
|
|
|
|
|
|
|
|
|
# 数据预处理
|
|
|
|
|
transform = transforms.Compose([
|
|
|
|
|
transforms.Resize((224, 224)),
|
|
|
|
|
transforms.ToTensor(),
|
|
|
|
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
class_names = {0: 'akiec', 1: 'bcc', 2: 'bkl', 3: 'df', 4: 'mel', 5: 'nv', 6: 'vasc'}
|
|
|
|
|
translations = { # 翻译映射
|
|
|
|
|
"akiec": "光化性角化病和皮内癌 / 鲍温病",
|
|
|
|
|
"bcc": "基底细胞癌",
|
|
|
|
|
"bkl": "良性角化样病变",
|
|
|
|
|
"df": "皮肤纤维瘤",
|
|
|
|
|
"mel": "黑色素瘤",
|
|
|
|
|
"nv": "黑素细胞痣",
|
|
|
|
|
"vasc": "血管病变",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def translate(name):
|
|
|
|
|
translated_text = " ".join([translations.get(word, word) for word in name.split()])
|
|
|
|
|
return translated_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_image(image):
|
|
|
|
|
# 检查并确保图像是 RGB 模式
|
|
|
|
|
if not image.mode == 'RGB':
|
|
|
|
|
image = image.convert('RGB') # 转换到 RGB 模式
|
|
|
|
|
|
|
|
|
|
# 应用图像预处理
|
|
|
|
|
image = transform(image).unsqueeze(0) # 转换成批次格式并增加维度
|
|
|
|
|
|
|
|
|
|
# 确保模型是图像分类模型
|
|
|
|
|
model_resnet.eval()
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
outputs = model_resnet(image) # 传入模型并获取输出
|
|
|
|
|
_, predicted = torch.max(outputs, 1)
|
|
|
|
|
predicted_class = class_names[predicted.item()]
|
|
|
|
|
|
|
|
|
|
name = translate(predicted_class)
|
|
|
|
|
# prompt_image = ChatPromptTemplate.from_messages([
|
|
|
|
|
# ("system", "你是一个专业的医疗专家,解释问题是什么,该如何治疗。文本开头加入根据您的图片分析等字样"),
|
|
|
|
|
# ("user", "{input}")
|
|
|
|
|
# ])
|
|
|
|
|
# output_parser_original = StrOutputParser()
|
|
|
|
|
# chain = prompt_image | llm | output_parser_original
|
|
|
|
|
# response = chain.invoke({"input": name}) # 调用链条,并传入数据
|
|
|
|
|
stop = StopOnTokens()
|
|
|
|
|
messages = [{"role": "user", "content": name}] # 只使用用户输入的消息
|
|
|
|
|
|
|
|
|
|
# 处理消息以生成模型输入2
|
|
|
|
|
model_inputs = tokenizer.apply_chat_template(messages,
|
|
|
|
|
add_generation_prompt=True,
|
|
|
|
|
tokenize=True,
|
|
|
|
|
return_tensors="pt").to(next(model.parameters()).device)
|
|
|
|
|
streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
|
|
|
|
|
generate_kwargs = {
|
|
|
|
|
"input_ids": model_inputs,
|
|
|
|
|
"streamer": streamer,
|
|
|
|
|
"max_new_tokens": 8192, # 固定的 max_length 值
|
|
|
|
|
"do_sample": True,
|
|
|
|
|
"top_p": 0.7, # 固定的 top_p 值
|
|
|
|
|
"temperature": 0.9, # 固定的 temperature 值
|
|
|
|
|
"stopping_criteria": StoppingCriteriaList([stop]),
|
|
|
|
|
"repetition_penalty": 1.2,
|
|
|
|
|
"eos_token_id": model.config.eos_token_id,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
|
|
|
|
t.start()
|
|
|
|
|
|
|
|
|
|
model_reply = ""
|
|
|
|
|
for new_token in streamer:
|
|
|
|
|
model_reply += new_token # 收集模型的回复
|
|
|
|
|
|
|
|
|
|
return name,model_reply # 返回最终的模型回复
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 构建Gradio界面
|
|
|
|
|
with gr.Blocks(css=custom_css) as demo:
|
|
|
|
|
with gr.Tabs():
|
|
|
|
|
with gr.Tab("对话问诊"):
|
|
|
|
|
# 添加HTML标题
|
|
|
|
|
gr.HTML(html_content)
|
|
|
|
|
with gr.Row():
|
|
|
|
|
with gr.Column():
|
|
|
|
|
with gr.Group(elem_id="example-container"):
|
|
|
|
|
example_buttons = [gr.Button(example, elem_id="example-btn") for example in examples]
|
|
|
|
|
with gr.Column(scale=10):
|
|
|
|
|
chatbot = gr.Chatbot(elem_id="gr-chatbot")
|
|
|
|
|
with gr.Row():
|
|
|
|
|
user_input = gr.Textbox(show_label=False, placeholder="请输入咨询的问题:", scale=20,
|
|
|
|
|
elem_id="textbox")
|
|
|
|
|
|
|
|
|
|
submitBtn = gr.Button("发送", elem_id="send-btn")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def user(query, history):
|
|
|
|
|
# 确保 history 中每一项都是长度为 2 的列表
|
|
|
|
|
if history is None:
|
|
|
|
|
history = []
|
|
|
|
|
history.append([query, ""]) # 第二个元素是空字符串,表示初始没有模型回复
|
|
|
|
|
return "", history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def send_example(example, history): # 处理用户点击示例按钮后自动发送
|
|
|
|
|
if history is None:
|
|
|
|
|
history = []
|
|
|
|
|
history.append([example, ""]) # 将示例作为输入,模型响应初始化为空
|
|
|
|
|
return history
|
|
|
|
|
|
|
|
|
|
submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
|
|
|
|
|
predict, [chatbot], chatbot
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 修改点击事件,点击时直接调用user和predict函数
|
|
|
|
|
for example_button, example_text in zip(example_buttons, examples):
|
|
|
|
|
example_button.click(
|
|
|
|
|
send_example,
|
|
|
|
|
inputs=[gr.State(value=example_text), chatbot],
|
|
|
|
|
outputs=chatbot
|
|
|
|
|
).then(
|
|
|
|
|
predict,
|
|
|
|
|
inputs=[chatbot],
|
|
|
|
|
outputs=chatbot
|
|
|
|
|
)
|
|
|
|
|
with gr.Tab("搜索专业文档"):
|
|
|
|
|
gr.Markdown("## 文件上传和模型问答")
|
|
|
|
|
with gr.Row():
|
|
|
|
|
with gr.Column():
|
|
|
|
|
file_input = gr.File(label="上传文件")
|
|
|
|
|
with gr.Column():
|
|
|
|
|
question_input = gr.Textbox(label="请输入问题",lines=8)
|
|
|
|
|
with gr.Row():
|
|
|
|
|
original_output = gr.Textbox(label="原模型输出")
|
|
|
|
|
output = gr.Textbox(label="模型回答")
|
|
|
|
|
with gr.Row():
|
|
|
|
|
context_output = gr.Textbox(label="相关上下文",scale=28)
|
|
|
|
|
submit_button = gr.Button("提交")
|
|
|
|
|
submit_button.click(process_uploaded_file, inputs=[file_input, question_input],
|
|
|
|
|
outputs=[original_output, output, context_output])
|
|
|
|
|
with gr.Tab("上传患病处图像"):
|
|
|
|
|
gr.Markdown("## 图像分类")
|
|
|
|
|
with gr.Row():
|
|
|
|
|
image_input = gr.Image(type="pil", label="上传患病处图像")
|
|
|
|
|
result_output = gr.Textbox(label="知识输出", lines=9)
|
|
|
|
|
with gr.Row():
|
|
|
|
|
image_output = gr.Textbox(label="识别结果", scale=28)
|
|
|
|
|
classify_button = gr.Button("识别")
|
|
|
|
|
classify_button.click(predict_image, inputs=image_input, outputs=[image_output, result_output])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo.queue()
|
|
|
|
|
demo.launch(share=True)
|