diff --git a/modelTest.py b/modelTest.py
new file mode 100644
index 0000000..35517ee
--- /dev/null
+++ b/modelTest.py
@@ -0,0 +1,84 @@
+import inspect
+import json
+import math
+import os
+import re
+import sys
+import time
+from collections import defaultdict
+from itertools import chain
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Union
+
+import bmtrain as bmt
+import numpy as np
+import torch
+from bmtrain import nccl
+from bmtrain.global_var import config as bmt_config
+
+sys.path.append("../../")
+from fm9g.arguments import get_args
+from fm9g.dragonfly.modeling_dragonfly import Dragonfly
+from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
+from fm9g.dragonfly.training_tasks.pretrain_indexed import CudaPrefetcher
+from fm9g.dragonfly.training_tasks.pretrain_indexed import MixedIndexedDataset
+from fm9g.dragonfly.training_tasks.pretrain_indexed import UnpadBatchedMixedDataset
+from fm9g.utils import exporter
+from fm9g.utils import logger
+from fm9g.utils.exporter import save_every_step_stats
+from fm9g.utils.training_stats import num_non_embedding_parameters
+from fm9g.utils.training_stats import num_parameters
+
+from apps.fm9g_2b.pretrain_dragonfly import initialize
+import argparse
+
+
+def get_tokenizer(tokenizer_path):
+    # Build a LlamaTokenizerFast directly from the SentencePiece model file.
+    from transformers import LlamaTokenizerFast
+    tokenizer = LlamaTokenizerFast(vocab_file=tokenizer_path)
+    return tokenizer
+
+
+def get_model(args):
+    config = DragonflyConfig.from_json_file(args.model_config)
+    config.tp = 1 if args.tp_size != 1 else 0  # TODO
+    config.pose_prob = args.pose_prob
+    config.pose_scaling_factor = args.pose_scaling_factor
+    config.rope_scaling_type = args.rope_scaling_type
+    config.rope_scaling_factor = args.rope_scaling_factor
+    config.orig_max_length = args.orig_max_length
+
+    bmt.print_rank("model config: {}".format(config))
+    bmt.print_rank("bmt config: {}".format(bmt.config))
+
+    model = Dragonfly(config)
+    if args.load is not None:
+        bmt.print_rank("args.load is not None, start to load checkpoint: " + args.load)
+        exporter.load_model_ckpt(args, model)
+    else:
+        bmt.print_rank("args.load is None, start to initialize parameters")
+        bmt.init_parameters(model)
+    return model
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    group = parser.add_argument_group("model", "model configuration")
+    group.add_argument("--model-config", type=str, default="apps/fm9g_2b/model_configs/2.4b.json", help="model configuration file")
+    group.add_argument("--eps", type=float, default=1e-5, help="eps in layernorm")
+    group.add_argument("--load", type=str, default="/root/autodl-tmp/fm9g_2b_hf_models", help="Path to a directory containing a model checkpoint.")
+    group.add_argument("--tp-size", default=1, type=int)
+
+    group = parser.add_argument_group("long_context_extend", "long context extend configurations")
+    group.add_argument("--pose_prob", default=0.0, type=float, help="Sample-level PoSE probability")
+    group.add_argument("--pose_scaling_factor", default=1.0, type=float, help="PoSE scaling factor, simulated input length = max_length * pose_scaling_factor")
+    group.add_argument("--rope_scaling_type", default="", type=str, choices=["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""], help="Context scaling type")
+    group.add_argument("--rope_scaling_factor", default=1, type=int, help="Context scaling factor")
+    group.add_argument("--orig_max_length", default=8192, type=int, help="Original context length before context extension")
+    args = parser.parse_args()
+
+    bmt.init_distributed(seed=42, tp_size=1)
+    bmt.synchronize()
+
+    # tokenizer
+    tokenizer_path = "apps/fm9g_2b/tokenizer/tokenizer.model"
+    tokenizer = get_tokenizer(tokenizer_path)
+
+    # model
+    model = get_model(args)
+
+    # inference
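
The hunk shown here ends at the `# inference` marker; the remaining lines of the 84-line file are not included above. As a rough illustration only, a minimal greedy-decoding loop over the loaded tokenizer and model might look like the sketch below. The `model(input_ids)` call and the `[batch, seq_len, vocab_size]` logits layout are assumptions for illustration, not the actual Dragonfly forward signature from the repository.

```python
# Hypothetical sketch only -- not part of modelTest.py and not taken from the
# FM9G codebase. Assumes model(input_ids) returns logits of shape
# [batch, seq_len, vocab_size]; the real Dragonfly API may differ.
import torch


def greedy_generate(model, tokenizer, prompt, max_new_tokens=32):
    input_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device="cuda")
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(input_ids)           # assumed shape: [1, seq, vocab]
        next_id = int(logits[0, -1].argmax())   # greedy: pick the most likely token
        if next_id == tokenizer.eos_token_id:
            break
        next_tok = torch.tensor([[next_id]], dtype=torch.long, device=input_ids.device)
        input_ids = torch.cat([input_ids, next_tok], dim=1)
    return tokenizer.decode(input_ids[0].tolist())
```

With that helper in place, the inference section could be as simple as `print(greedy_generate(model, tokenizer, "Hello"))`, run under the same bmtrain-initialized process as the rest of the script.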