"""Load the FM9G 2B Dragonfly model and tokenizer for inference."""

import argparse
import sys

import bmtrain as bmt
import torch

sys.path.append("../../")
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
from fm9g.utils import exporter


def get_tokenizer(tokenizer_path):
    """Build the tokenizer from a SentencePiece model file."""
    from transformers import LlamaTokenizerFast

    return LlamaTokenizerFast(vocab_file=tokenizer_path)


def get_model(args):
    """Build a Dragonfly model and load (or randomly initialize) its parameters."""
    config = DragonflyConfig.from_json_file(args.model_config)
    config.tp = 1 if args.tp_size != 1 else 0  # TODO
    config.pose_prob = args.pose_prob
    config.pose_scaling_factor = args.pose_scaling_factor
    config.rope_scaling_type = args.rope_scaling_type
    config.rope_scaling_factor = args.rope_scaling_factor
    config.orig_max_length = args.orig_max_length

    bmt.print_rank("model config: {}".format(config))
    bmt.print_rank("bmt config: {}".format(bmt.config))

    model = Dragonfly(config)
    if args.load is not None:
        bmt.print_rank("args.load is not None, loading checkpoint from: " + args.load)
        exporter.load_model_ckpt(args, model)
    else:
        bmt.print_rank("args.load is None, initializing parameters")
        bmt.init_parameters(model)
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    group = parser.add_argument_group("model", "model configuration")
    group.add_argument(
        "--model-config",
        type=str,
        default="apps/fm9g_2b/model_configs/2.4b.json",
        help="model configuration file",
    )
    group.add_argument("--eps", type=float, default=1e-5, help="eps in layernorm")
    group.add_argument(
        "--load",
        type=str,
        default="/root/autodl-tmp/fm9g_2b_hf_models",
        help="Path to a directory containing a model checkpoint.",
    )
    group.add_argument("--tp-size", default=1, type=int)

    group = parser.add_argument_group("long_context_extend", "long-context extension configuration")
    group.add_argument("--pose_prob", default=0.0, type=float, help="Sample-level PoSE probability")
    group.add_argument(
        "--pose_scaling_factor",
        default=1.0,
        type=float,
        help="PoSE scaling factor; simulated input length = max_length * pose_scaling_factor",
    )
    group.add_argument(
        "--rope_scaling_type",
        default="",
        type=str,
        choices=["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""],
        help="Context scaling type",
    )
    group.add_argument("--rope_scaling_factor", default=1, type=int, help="Context scaling factor")
    group.add_argument(
        "--orig_max_length",
        default=8192,
        type=int,
        help="Original context length before context extension",
    )

    args = parser.parse_args()

    bmt.init_distributed(seed=42, tp_size=args.tp_size)
    bmt.synchronize()

    # tokenizer
    tokenizer_path = "apps/fm9g_2b/tokenizer/tokenizer.model"
    tokenizer = get_tokenizer(tokenizer_path)

    # model
    model = get_model(args)

    # inference
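    # --- Hedged inference sketch -------------------------------------------
    # The original file ends at the "inference" marker, so the loop below is a
    # minimal greedy-decoding illustration only. It ASSUMES the model can be
    # called as model(input_ids, position_ids=...) and returns vocabulary
    # logits of shape (batch, seq_len, vocab_size); Dragonfly's actual forward
    # signature lives in fm9g/dragonfly/modeling_dragonfly.py (it may expect
    # unpadded, cu_seqlens-style inputs) and should be checked before use.
    prompt = "The highest mountain in the world is"
    token_ids = tokenizer.encode(prompt)

    model.eval()
    max_new_tokens = 64
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Re-encode the full sequence each step (no KV cache in this sketch).
            input_ids = torch.tensor([token_ids], dtype=torch.int32, device="cuda")
            position_ids = torch.arange(
                input_ids.size(1), dtype=torch.int32, device="cuda"
            ).unsqueeze(0)
            logits = model(input_ids, position_ids=position_ids)  # assumed signature
            # Greedy decoding: pick the highest-probability next token.
            next_id = int(torch.argmax(logits[0, -1]))
            if next_id == tokenizer.eos_token_id:
                break
            token_ids.append(next_id)

    bmt.print_rank(tokenizer.decode(token_ids))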