From dc71fc11f0cab1ce343a23f0ffe333ea1035db98 Mon Sep 17 00:00:00 2001
From: paxflsu4r <198028451@qq.com>
Date: Mon, 20 Jan 2025 16:56:54 +0800
Subject: [PATCH] ADD file via upload

---
 fm9g.py | 211 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 211 insertions(+)
 create mode 100644 fm9g.py

diff --git a/fm9g.py b/fm9g.py
new file mode 100644
index 0000000..128a000
--- /dev/null
+++ b/fm9g.py
@@ -0,0 +1,211 @@
+import io
+import json
+from typing import Dict
+from typing import IO
+from typing import List
+
+import pkg_resources
+from pytrie import StringTrie
+
+
+def load_vocab(fp: IO[bytes]) -> Dict[str, int]:
+    """Load a newline-delimited, JSON-quoted vocabulary file into a dict."""
+    vocab: Dict[str, int] = {}
+
+    reader = io.TextIOWrapper(fp, encoding="utf-8")
+    for token in reader.readlines():
+        token = token.strip()
+        if len(token) == 0:
+            continue
+        token = json.loads(token)
+        vocab[token] = len(vocab)
+    return vocab
+
+
+class FM9GTokenizer(object):
+    def __init__(self, path=None):
+        # SentencePiece-style sentinel tokens.
+        self.unk_token = "<unk>"
+        self.bos_token = "<s>"
+        self.eos_token = "</s>"
+        # <0x00> ... <0xFF>: one token per possible UTF-8 byte, used as a
+        # fallback for text not covered by the vocabulary.
+        self.byte_list = ["<0x0{}>".format(hex(i).upper()[2:]) for i in range(0x10)] + [
+            "<0x{}>".format(hex(i).upper()[2:]) for i in range(0x10, 0x100)
+        ]
+
+        self._special_token_set = set([self.unk_token, self.bos_token, self.eos_token] + self.byte_list)
+
+        if path:
+            all_tokens = load_vocab(io.FileIO(path, "rb"))
+        else:
+            all_tokens = load_vocab(pkg_resources.resource_stream("fm9g", "/fm9g/vocabs/fm9g.txt"))
+
+        self.encoder: Dict[str, int] = {}
+        self._special_encoder: Dict[str, int] = {}
+        for token, token_id in all_tokens.items():
+            if token in self._special_token_set:
+                self._special_encoder[token] = token_id
+            else:
+                self.encoder[token] = token_id
+
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self._byte_decoder = {self._special_encoder[token]: i for i, token in enumerate(self.byte_list)}
+
+        self._max_word_len = max([len(x) for x in self.encoder.keys()])
+
+        # For each possible first character, the length of the longest vocab
+        # entry starting with it; bounds the search window in get_piece().
+        self._len_word_first = {}
+        for x in self.encoder.keys():
+            if not x[0] in self._len_word_first:
+                self._len_word_first[x[0]] = 1
+            if len(x) > self._len_word_first[x[0]]:
+                self._len_word_first[x[0]] = len(x)
+        self.tencoder = StringTrie(self.encoder)
+
+    def get_piece(self, text: str) -> str:
+        """Return the longest vocab entry prefixing `text`, else its first char."""
+        if text[0] in self._len_word_first:
+            text = text[: self._len_word_first[text[0]]]
+            len_text = len(text)
+            for i in range(len(text)):
+                sub = text[: len_text - i]
+                if sub in self.encoder:
+                    return sub
+        return text[0]
+
+    @property
+    def vocab_size(self):
+        return len(self)
+
+    @property
+    def eos_id(self):
+        return self._special_encoder[self.eos_token]
+
+    @property
+    def bos_id(self):
+        return self._special_encoder[self.bos_token]
+
+    @property
+    def unk_id(self):
+        return self._special_encoder[self.unk_token]
+
+    def __len__(self):
+        return len(self.encoder) + len(self._special_encoder)
+
+    def tokenize(self, text: str) -> List[str]:
+        """Greedy longest-match segmentation of `text` into vocab pieces."""
+        output_tokens: List[str] = []
+        st = 0
+        while st < len(text):
+            piece = self.get_piece(text[st:])
+            output_tokens.append(piece)
+            st += len(piece)
+        return output_tokens
+
+    @staticmethod
+    def escape(text: str) -> str:
+        return text
+
+    @staticmethod
+    def unescape(text: str) -> str:
+        return text
+
+    def encode(self, text: str) -> List[int]:
+        # if len(text) > 20480:
+        #     return [0 for _ in range(20480)]
+        ret = []
+        for x in self.tokenize(text):
+            if x in self.encoder:
+                ret.append(self.encoder[x])
+            else:
+                # Out-of-vocab piece: fall back to per-byte tokens.
+                ret.extend(self._encode_unicode(x))
+        return ret
+
+    def decode(self, tokens: List[int]):
+        """Decode ids into a string."""
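+        # Byte-fallback decoding: ids missing from self.decoder but present in
+        # self._byte_decoder each stand for a single UTF-8 byte. Runs of such
+        # byte tokens are greedily grouped (up to four at a time) and decoded
+        # as one UTF-8 sequence; this assumes the run is well formed.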
+        ret = []
+        st = 0
+
+        while st < len(tokens):
+            if tokens[st] in self.decoder:
+                ret.append(self.decoder[tokens[st]])
+                st += 1
+            elif tokens[st] in self._byte_decoder:
+                if (
+                    st + 3 < len(tokens)
+                    and tokens[st + 1] in self._byte_decoder
+                    and tokens[st + 2] in self._byte_decoder
+                    and tokens[st + 3] in self._byte_decoder
+                ):
+                    first_id = self._byte_decoder[tokens[st]]
+                    plane_id = self._byte_decoder[tokens[st + 1]]
+                    row_id = self._byte_decoder[tokens[st + 2]]
+                    cell_id = self._byte_decoder[tokens[st + 3]]
+                    int_bytes = int.to_bytes(first_id << 24 | plane_id << 16 | row_id << 8 | cell_id, 4, "big")
+                    # errors="replace" substitutes U+FFFD for malformed
+                    # sequences, so this can never raise UnicodeDecodeError.
+                    ret.append(int_bytes.decode("utf-8", errors="replace"))
+                    st += 4
+                elif (
+                    st + 2 < len(tokens)
+                    and tokens[st + 1] in self._byte_decoder
+                    and tokens[st + 2] in self._byte_decoder
+                ):
+                    plane_id = self._byte_decoder[tokens[st]]
+                    row_id = self._byte_decoder[tokens[st + 1]]
+                    cell_id = self._byte_decoder[tokens[st + 2]]
+                    int_bytes = int.to_bytes(plane_id << 16 | row_id << 8 | cell_id, 3, "big")
+                    ret.append(int_bytes.decode("utf-8", errors="replace"))
+                    st += 3
+                elif st + 1 < len(tokens) and tokens[st + 1] in self._byte_decoder:
+                    row_id = self._byte_decoder[tokens[st]]
+                    cell_id = self._byte_decoder[tokens[st + 1]]
+                    int_bytes = int.to_bytes(row_id << 8 | cell_id, 2, "big")
+                    ret.append(int_bytes.decode("utf-8", errors="replace"))
+                    st += 2
+                else:
+                    cell_id = self._byte_decoder[tokens[st]]
+                    ret.append(int.to_bytes(cell_id, 1, "big").decode("utf-8", errors="replace"))
+                    st += 1
+            elif tokens[st] == self.eos_id:
+                ret.append(self.eos_token)
+                st += 1
+            elif tokens[st] == self.bos_id:
+                ret.append(self.bos_token)
+                st += 1
+            else:
+                ret.append(self.unk_token)
+                st += 1
+        return "".join(ret)
+
+    def _encode_unicode(self, token):
+        # Map each UTF-8 byte of `token` to the id of its <0x..> byte token.
+        ids = []
+        utf8_id = token.encode("utf-8")
+        for _id in utf8_id:
+            ids.append(self._special_encoder[self.byte_list[_id]])
+        return ids
+
+    def next_token(self, text):
+        # Fast next-token matching via the trie: longest vocab prefix of
+        # `text`, with a per-byte fallback when no prefix matches.
+        token, token_id = self.tencoder.longest_prefix_item(text, (None, None))
+        if token is None:
+            token = text[0]
+            token_ids = self._encode_unicode(token)
+        else:
+            token_ids = [token_id]
+        return token, token_ids
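
A minimal usage sketch for review (not part of the patch), assuming a local vocab file with one JSON-quoted token per line; the path "vocabs/fm9g.txt" and the variable names are illustrative only:

    from fm9g import FM9GTokenizer

    # Hypothetical local vocab path; load_vocab() expects one JSON string per line.
    tok = FM9GTokenizer(path="vocabs/fm9g.txt")

    ids = tok.encode("Hello, world!")   # vocab pieces, with <0x..> byte-token fallback
    print(tok.decode(ids))              # should round-trip back to "Hello, world!"

    # A character outside the vocab is emitted as its raw UTF-8 bytes
    # (four <0x..> tokens for a 4-byte codepoint) and regrouped by decode():
    print(tok.decode(tok.encode("🙂")))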