# -*- coding: utf-8 -*-
"""
Created on 2024/9/30 21:30

@author: Whenxuan Wang

@email: wwhenxuan@gmail.com

@url: https://github.com/wwhenxuan/SymTime
"""
import torch

from torch import nn

from transformers import DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer
from transformers.modeling_outputs import MaskedLMOutput

from typing import Tuple, Any


class LLM(nn.Module):
    """Wrap a pretrained language model as a general interface for symbolic encoders."""

    def __init__(
        self,
        llm_name: str = "DistilBert",
        llm_layers: int = 6,
        hidden_size: int = 768,
        freeze_layers: int = 3,
    ) -> None:
        super().__init__()
        # Record basic information about the LLM
        self.llm_name = llm_name
        self.llm_layers = llm_layers  # number of transformer layers (DistilBert base has 6)
        self.freeze_layers = freeze_layers

        # Load the basic config, the model, and the tokenizer for the chosen LLM
        self.llm_configs, self.llm, self.tokenizer = self.init_llm()
        self.hidden_size = hidden_size

        # Freeze the parameters of the first `freeze_layers` layers of the LLM
        self.freeze()

    def forward(
        self, input_ids: torch.Tensor, att_mask: torch.Tensor, labels: torch.Tensor
    ) -> MaskedLMOutput:
        """Forward propagation through the LLM; the loss can be read directly from the output as `outputs.loss`."""
        outputs = self.llm(input_ids=input_ids, attention_mask=att_mask, labels=labels)
        return outputs

    def freeze(self) -> None:
        """Freeze the parameters of the first `freeze_layers` layers of the LLM."""
        for name, param in self.llm.named_parameters():
            # Parameter names embed the layer index, e.g. "transformer.layer.0..."
            if any(f"layer.{index}" in name for index in range(self.freeze_layers)):
                param.requires_grad = False

    def init_llm(self) -> Tuple[DistilBertConfig, Any, Any]:
        """Select and initialize the LLM to use based on the model name given at construction."""
        if self.llm_name == "DistilBert":
            llm_config = DistilBertConfig.from_pretrained(
                # "distilbert/snapshots/12040accade4e8a0f71eabdb258fecc2e7e948be"
                "distilbert-base-uncased"
            )
            llm_config.output_attentions = True
            llm_config.output_hidden_states = True
            llm, tokenizer = self.load_DistilBert(config=llm_config)
        else:
            raise ValueError(f"Unsupported LLM name: {self.llm_name}")
        return llm_config, llm, tokenizer

    def load_DistilBert(self, config: DistilBertConfig) -> Tuple[Any, Any]:
        """Load the DistilBert model and tokenizer, preferring local files over a download."""
        try:
            # Try to load the pretrained model parameters from the local device
            llm = DistilBertForMaskedLM.from_pretrained(
                # "distilbert/snapshots/12040accade4e8a0f71eabdb258fecc2e7e948be",
                "distilbert-base-uncased",
                # torch_dtype=torch.float16,
                local_files_only=True,
                config=config,
            )
        except EnvironmentError:
            # Fall back to downloading the pretrained parameters
            print(
                f"{self.llm_name} not found locally, trying to load from the network..."
            )
            llm = DistilBertForMaskedLM.from_pretrained(
                "distilbert-base-uncased",
                # torch_dtype=torch.float16,
                local_files_only=False,
                config=config,
            )
        try:
            # Try to load the tokenizer from the local device
            tokenizer = DistilBertTokenizer.from_pretrained(
                # "distilbert/snapshots/12040accade4e8a0f71eabdb258fecc2e7e948be",
                "distilbert-base-uncased",
                local_files_only=True,
            )
        except EnvironmentError:
            # Fall back to downloading the tokenizer
            tokenizer = DistilBertTokenizer.from_pretrained(
                "distilbert-base-uncased", local_files_only=False
            )
        return llm, tokenizer
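

if __name__ == "__main__":
    # Usage sketch (not part of the original module): exercises the LLM wrapper
    # end to end. It assumes the "distilbert-base-uncased" weights are available
    # locally or can be downloaded; the sample sentence and the choice of using
    # the input ids as labels are illustrative only.
    model = LLM(llm_name="DistilBert", freeze_layers=3)

    # Check how many parameter tensors the freeze() call actually disabled
    frozen = sum(1 for _, p in model.llm.named_parameters() if not p.requires_grad)
    total = sum(1 for _ in model.llm.named_parameters())
    print(f"Frozen parameter tensors: {frozen}/{total}")

    # Run a masked-LM forward pass; passing labels makes the model compute the loss
    text = "The symbolic encoder maps expressions to [MASK] embeddings."
    encoded = model.tokenizer(text, return_tensors="pt")
    outputs = model(
        input_ids=encoded["input_ids"],
        att_mask=encoded["attention_mask"],
        labels=encoded["input_ids"],
    )
    print(f"Masked-LM loss: {outputs.loss.item():.4f}")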