ADD file via upload

branch: main
paxflsu4r, 7 months ago
parent 7ed7ea5366
commit b5d3efda0e

@@ -0,0 +1,84 @@
import argparse
import sys

import bmtrain as bmt
import torch

sys.path.append("../../")
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
from fm9g.utils import exporter
from fm9g.utils.training_stats import num_parameters


def get_tokenizer(tokenizer_path):
    """Build a LlamaTokenizerFast directly from the SentencePiece vocab file."""
    from transformers import LlamaTokenizerFast

    tokenizer = LlamaTokenizerFast(vocab_file=tokenizer_path)
    return tokenizer
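

# A quick sanity check for the tokenizer (illustrative only; the prompt string
# and variable names below are placeholders, not part of the original script):
#
#   tok = get_tokenizer("apps/fm9g_2b/tokenizer/tokenizer.model")
#   ids = tok.encode("hello world")  # list[int]; a BOS token may be prepended
#   print(tok.decode(ids, skip_special_tokens=True))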


def get_model(args):
    config = DragonflyConfig.from_json_file(args.model_config)
    config.tp = 1 if args.tp_size != 1 else 0  # TODO
    config.pose_prob = args.pose_prob
    config.pose_scaling_factor = args.pose_scaling_factor
    config.rope_scaling_type = args.rope_scaling_type
    config.rope_scaling_factor = args.rope_scaling_factor
    config.orig_max_length = args.orig_max_length
    bmt.print_rank("model config: {}".format(config))
    bmt.print_rank("bmt config: {}".format(bmt.config))
    model = Dragonfly(config)
    if args.load is not None:
        bmt.print_rank("args.load is not None, loading checkpoint from: " + args.load)
        exporter.load_model_ckpt(args, model)
    else:
        bmt.print_rank("args.load is None, initializing parameters randomly")
        bmt.init_parameters(model)
    return model
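

# Optional sanity check after building the model. An assumption, not shown in
# the original: num_parameters (imported above from fm9g.utils.training_stats)
# is presumed to take the model as its only argument, as in the pretraining
# script:
#
#   bmt.print_rank("#parameters: {}".format(num_parameters(model)))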


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group("model", "model configuration")
    group.add_argument("--model-config", type=str, default="apps/fm9g_2b/model_configs/2.4b.json", help="model configuration file")
    group.add_argument("--eps", type=float, default=1e-5, help="eps in layernorm")
    group.add_argument("--load", type=str, default="/root/autodl-tmp/fm9g_2b_hf_models", help="Path to a directory containing a model checkpoint.")
    group.add_argument("--tp-size", default=1, type=int)
    group = parser.add_argument_group("long_context_extend", "long-context extension configuration")
    group.add_argument("--pose_prob", default=0.0, type=float, help="Sample-level PoSE probability")
    group.add_argument("--pose_scaling_factor", default=1.0, type=float, help="PoSE scaling factor; simulated input length = max_length * pose_scaling_factor")
    group.add_argument("--rope_scaling_type", default="", type=str, choices=["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""], help="Context scaling type")
    group.add_argument("--rope_scaling_factor", default=1, type=int, help="Context scaling factor")
    group.add_argument("--orig_max_length", default=8192, type=int, help="Original context length before context extension")
    args = parser.parse_args()

    # Use the parsed tp_size so --tp-size and the distributed setup stay consistent.
    bmt.init_distributed(seed=42, tp_size=args.tp_size)
    bmt.synchronize()
    # --- tokenizer ---
    tokenizer_path = "apps/fm9g_2b/tokenizer/tokenizer.model"
    tokenizer = get_tokenizer(tokenizer_path)

    # --- model ---
    model = get_model(args)

    # --- inference ---
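    # A hypothetical greedy-decoding outline (an assumption, NOT the author's
    # code: Dragonfly's forward signature is not shown here, so the model call
    # and its `logits` return value below must be checked against
    # fm9g/dragonfly/modeling_dragonfly.py before use):
    #
    #   ids = torch.tensor([tokenizer.encode("your prompt")], device="cuda")
    #   with torch.no_grad():
    #       for _ in range(128):                      # max new tokens
    #           logits = model(ids)                   # assumed call signature
    #           next_id = int(torch.argmax(logits[0, -1]))
    #           ids = torch.cat([ids, torch.tensor([[next_id]], device="cuda")], dim=1)
    #   print(tokenizer.decode(ids[0].tolist(), skip_special_tokens=True))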