from typing import Any, List, Mapping, Optional

from transformers import AutoConfig, AutoModel, AutoTokenizer

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens

from args import args


class CustomLLM(LLM):
    model_kwargs: Optional[dict] = None
    pretrained_model_name = args.pretrained_model_name
    top_k = args.topk
    temperature = args.temperature
    device = args.device

    @property
    def _llm_type(self) -> str:
        return "CustomLLM"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return a dict of the parameters that identify this LLM instance."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"model_kwargs": _model_kwargs},
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        _model_kwargs = self.model_kwargs or {}

        # Merge sampling parameters into the model config; instance-level
        # model_kwargs and per-call kwargs take precedence over the defaults.
        payload = {
            "temperature": self.temperature,
            "top_k": self.top_k,
        }
        model_config = AutoConfig.from_pretrained(
            self.pretrained_model_name, trust_remote_code=True
        )
        model_config.update(payload)
        model_config.update(_model_kwargs)
        model_config.update(kwargs)

        # Note: the model and tokenizer are reloaded on every call; caching
        # them on the instance would avoid the repeated load cost.
        model = (
            AutoModel.from_pretrained(
                self.pretrained_model_name,
                config=model_config,
                trust_remote_code=True,
            )
            .half()
            .to(self.device)
        )
        tokenizer = AutoTokenizer.from_pretrained(
            self.pretrained_model_name, trust_remote_code=True
        )
        model = model.eval()

        # chat() is provided by the remote model code (e.g. ChatGLM-style
        # models); history is discarded since each call starts fresh.
        response, history = model.chat(tokenizer, prompt, history=[])
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        return response
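

# A minimal usage sketch, assuming the `args` module supplies valid values
# for pretrained_model_name, topk, temperature, and device (e.g. a
# ChatGLM-style checkpoint on a CUDA device). The prompt and stop list
# below are illustrative, not part of the original code.
if __name__ == "__main__":
    llm = CustomLLM()

    # LangChain routes the call through _call, which loads the model,
    # runs chat(), and truncates at the first stop token if one is given.
    answer = llm("What is LangChain?", stop=["\nQuestion:"])
    print(answer)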