You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

59 lines
1.8 KiB

1 year ago
from typing import Any, List, Mapping, Optional
from transformers import AutoTokenizer, AutoModel,AutoConfig
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from args import args
class CustomLLM(LLM):
    """LangChain LLM wrapper around a locally loaded Hugging Face chat model.

    The checkpoint name, sampling settings and target device default to the
    values exposed by the project-level ``args`` object. The checkpoint is
    loaded with ``trust_remote_code=True`` and is expected to expose a
    ``chat(tokenizer, prompt, history=...)`` method returning a
    ``(response, history)`` pair.
    """

    # Extra configuration entries merged into the model config for each call.
    model_kwargs: Optional[dict] = None
    # The defaults below are read once, at class-definition time, from ``args``.
    pretrained_model_name = args.pretrained_model_name
    top_k = args.topk
    temperature = args.temperature
    device = args.device
    # Single-slot cache mapping a configuration key to ``(model, tokenizer)``.
    # It is only ever mutated in place, never rebound.
    # NOTE(review): relies on underscore-prefixed attributes not being turned
    # into pydantic fields -- confirm for the pydantic version in use.
    _loaded_models: dict = {}

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this LLM implementation."""
        return "CustomLLM"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return a dictionary of the selected attributes, for printing."""
        return {"model_kwargs": self.model_kwargs or {}}

    def _load_model(self, overrides: Mapping[str, Any]) -> tuple:
        """Return ``(model, tokenizer)`` for the given config overrides.

        The pair is loaded from ``pretrained_model_name`` only when the
        checkpoint, device or overrides differ from the previously loaded
        pair; otherwise the cached pair is reused so the weights are not
        re-read on every prompt.

        Args:
            overrides: Entries applied on top of the checkpoint's own config.

        Returns:
            A tuple of the model (half precision, moved to ``device``, in
            eval mode) and its tokenizer.
        """
        cache_key = (
            self.pretrained_model_name,
            str(self.device),
            # Sort by stringified key so the cache key does not depend on
            # insertion order and never compares keys of mixed types.
            repr(sorted(overrides.items(), key=lambda item: str(item[0]))),
        )
        cached = self._loaded_models.get(cache_key)
        if cached is not None:
            return cached

        # Keep at most one model alive: drop the previous entry before the
        # new weights are loaded so both are not held at the same time.
        self._loaded_models.clear()

        model_config = AutoConfig.from_pretrained(
            self.pretrained_model_name, trust_remote_code=True
        )
        model_config.update(dict(overrides))
        model = AutoModel.from_pretrained(
            self.pretrained_model_name,
            config=model_config,
            trust_remote_code=True,
        ).half().to(self.device)
        tokenizer = AutoTokenizer.from_pretrained(
            self.pretrained_model_name, trust_remote_code=True
        )
        model = model.eval()

        self._loaded_models[cache_key] = (model, tokenizer)
        return model, tokenizer

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any,) -> str:
        """Generate a reply to ``prompt`` with the wrapped chat model.

        Args:
            prompt: Text passed to the model's ``chat`` method with an empty
                history.
            stop: Optional stop sequences; when non-empty the reply is passed
                through ``enforce_stop_tokens``.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Extra config entries, applied after ``temperature``,
                ``top_k`` and ``model_kwargs`` (later entries win).

        Returns:
            The model's reply as a string.
        """
        # Same precedence as updating the config three times in sequence:
        # class-level sampling settings < model_kwargs < per-call kwargs.
        overrides = {
            "temperature": self.temperature,
            "top_k": self.top_k,
        }
        overrides.update(self.model_kwargs or {})
        overrides.update(kwargs)

        model, tokenizer = self._load_model(overrides)
        response, _ = model.chat(tokenizer, prompt, history=[])

        # An empty stop list carries no stop sequences, so it is treated like
        # ``None`` and the reply is returned untouched.
        if stop:
            response = enforce_stop_tokens(response, stop)
        return response