You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
59 lines
1.8 KiB
59 lines
1.8 KiB
1 year ago
|
from typing import Any, List, Mapping, Optional
|
||
|
from transformers import AutoTokenizer, AutoModel,AutoConfig
|
||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||
|
from langchain.llms.base import LLM
|
||
|
from langchain.llms.utils import enforce_stop_tokens
|
||
|
from args import args
|
||
|
|
||
|
|
||
|
# Most recently loaded (tokenizer, model) pair, keyed by everything that
# affects how it was built. At most one entry is kept, so repeated calls reuse
# the loaded weights without accumulating copies when the overrides change.
_model_cache: dict = {}


def _load_tokenizer_and_model(name: str, device: Any, overrides: Mapping[str, Any]):
    """Return a ``(tokenizer, model)`` pair for *name*, loading it only when needed.

    :param name: model name or path passed to the ``from_pretrained`` loaders.
    :param device: target passed to ``model.to``.
    :param overrides: attributes written onto the model config before loading;
        a different set of overrides triggers a reload.
    :return: the tokenizer and the model (already switched to eval mode).
    """
    key = (
        name,
        str(device),
        repr(sorted(overrides.items(), key=lambda kv: str(kv[0]))),
    )
    cached = _model_cache.get(key)
    if cached is not None:
        return cached

    # Drop any previously loaded model first so its memory can be reclaimed
    # before new weights are loaded.
    _model_cache.clear()

    model_config = AutoConfig.from_pretrained(name, trust_remote_code=True)
    model_config.update(dict(overrides))
    # NOTE(review): .half() casts the weights to fp16; confirm that the
    # configured device supports fp16 inference.
    model = AutoModel.from_pretrained(
        name, config=model_config, trust_remote_code=True
    ).half().to(device)
    tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)

    cached = (tokenizer, model.eval())
    _model_cache[key] = cached
    return cached


class CustomLLM(LLM):
    """LangChain LLM wrapper around a Hugging Face model that exposes ``model.chat``.

    Defaults for the model name, sampling settings and device come from the
    project-level ``args`` object.
    """

    # Extra attributes written onto the model config; applied after the
    # sampling defaults below and before per-call keyword arguments.
    model_kwargs: Optional[dict] = None
    # Model name/path handed to the transformers ``from_pretrained`` loaders.
    pretrained_model_name = args.pretrained_model_name
    # Sampling settings copied into the model config as ``top_k``/``temperature``.
    top_k = args.topk
    temperature = args.temperature
    # Target passed to ``model.to``.
    device = args.device

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this LLM type."""
        return "CustomLLM"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return a dict of the attributes that identify this configuration.

        LangChain reports these (e.g. when printing the LLM) and may use them
        to distinguish instances, so the model name and sampling settings are
        included alongside ``model_kwargs``.
        """
        return {
            "pretrained_model_name": self.pretrained_model_name,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "model_kwargs": self.model_kwargs or {},
        }

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any,) -> str:
        """Generate a single-turn reply to *prompt*.

        The model config is built from ``temperature``/``top_k``, then
        ``model_kwargs``, then ``kwargs`` (later entries win). The tokenizer and
        model are loaded once per distinct configuration and reused afterwards
        instead of being reloaded on every call.

        :param prompt: query passed to ``model.chat`` with an empty history.
        :param stop: optional stop strings; when given, the response is
            post-processed with ``enforce_stop_tokens``.
        :param run_manager: LangChain callback manager (not used here).
        :param kwargs: extra attributes written onto the model config.
        :return: the model's response text.
        """
        overrides = {
            "temperature": self.temperature,
            "top_k": self.top_k,
            **(self.model_kwargs or {}),
            **kwargs,
        }
        tokenizer, model = _load_tokenizer_and_model(
            self.pretrained_model_name, self.device, overrides
        )

        # NOTE(review): some remote-code ``chat`` implementations accept their
        # own sampling arguments with defaults, which may take precedence over
        # the config values set above; confirm this for the model in use.
        response, _history = model.chat(tokenizer, prompt, history=[])

        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        return response
|
||
|
|