You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

59 lines
1.8 KiB

from typing import Any, List, Mapping, Optional
from transformers import AutoTokenizer, AutoModel,AutoConfig
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from args import args
# Module-level caches so repeated calls do not reload the (half-precision,
# potentially multi-gigabyte) model weights and the tokenizer from disk on
# every single request — the original code rebuilt both inside _call.
_MODEL_CACHE: dict = {}
_TOKENIZER_CACHE: dict = {}


class CustomLLM(LLM):
    """LangChain-compatible LLM wrapper around a local HuggingFace chat model.

    Model name, sampling parameters, and target device are read from the
    project-level ``args`` namespace at class-definition time.
    """

    # Extra kwargs merged into the model config on every call.
    model_kwargs: Optional[dict] = None
    # NOTE(review): the attributes below are deliberately left un-annotated so
    # they stay plain class attributes (not pydantic fields), preserving the
    # original class's construction contract.
    pretrained_model_name = args.pretrained_model_name
    top_k = args.topk
    temperature = args.temperature
    device = args.device

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses to label this LLM implementation."""
        return "CustomLLM"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the parameters that identify this LLM instance."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"model_kwargs": _model_kwargs},
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a chat completion for ``prompt``.

        Args:
            prompt: User prompt forwarded to ``model.chat`` (empty history).
            stop: Optional stop strings; the response is truncated at the
                first occurrence of any of them via ``enforce_stop_tokens``.
            run_manager: Accepted for LangChain compatibility; unused here.
            **kwargs: Extra overrides merged into the model config, taking
                precedence over both the instance defaults and ``model_kwargs``.

        Returns:
            The model's text response (the returned history is discarded).
        """
        _model_kwargs = self.model_kwargs or {}
        # Instance-level sampling defaults; overridden by model_kwargs, which
        # are in turn overridden by per-call kwargs (same order as original).
        payload = {
            "temperature": self.temperature,
            "top_k": self.top_k,
        }
        model_config = AutoConfig.from_pretrained(
            self.pretrained_model_name, trust_remote_code=True
        )
        model_config.update(payload)
        model_config.update(_model_kwargs)
        model_config.update(kwargs)

        # Key the cache on everything that affects the loaded weights so that
        # a changed config still triggers a fresh load (matching the original
        # per-call behaviour), while identical calls reuse the cached model.
        cache_key = (
            self.pretrained_model_name,
            model_config.to_json_string(),
            str(self.device),
        )
        model = _MODEL_CACHE.get(cache_key)
        if model is None:
            model = (
                AutoModel.from_pretrained(
                    self.pretrained_model_name,
                    config=model_config,
                    trust_remote_code=True,
                )
                .half()
                .to(self.device)
                .eval()
            )
            _MODEL_CACHE[cache_key] = model

        tokenizer = _TOKENIZER_CACHE.get(self.pretrained_model_name)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(
                self.pretrained_model_name, trust_remote_code=True
            )
            _TOKENIZER_CACHE[self.pretrained_model_name] = tokenizer

        response, history = model.chat(tokenizer, prompt, history=[])
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        return response