import inspect
import json
import math
import os
import re
import sys
import time
from collections import defaultdict
from itertools import chain
from typing import Any
from typing import Dict
from typing import List
from typing import Union

import bmtrain as bmt
import numpy as np
import torch
from bmtrain import nccl
from bmtrain.global_var import config as bmt_config

sys.path.append("../../")
from fm9g.arguments import get_args
from fm9g.dragonfly.modeling_dragonfly import Dragonfly
from fm9g.dragonfly.modeling_dragonfly import DragonflyConfig
from fm9g.dragonfly.training_tasks.pretrain_indexed import CudaPrefetcher
from fm9g.dragonfly.training_tasks.pretrain_indexed import MixedIndexedDataset
from fm9g.dragonfly.training_tasks.pretrain_indexed import UnpadBatchedMixedDataset
from fm9g.utils import exporter
from fm9g.utils import logger
from fm9g.utils.exporter import save_every_step_stats
from fm9g.utils.training_stats import num_non_embedding_parameters
from fm9g.utils.training_stats import num_parameters

from apps.fm9g_2b.pretrain_dragonfly import initialize
import argparse


def get_tokenizer(tokenizer_path):
    # Build the tokenizer directly from the SentencePiece vocab file.
    from transformers import LlamaTokenizerFast

    tokenizer = LlamaTokenizerFast(vocab_file=tokenizer_path)
    return tokenizer


def get_model(args):
    config = DragonflyConfig.from_json_file(args.model_config)
    config.tp = 1 if args.tp_size != 1 else 0  # TODO
    # Long-context extension settings (PoSE / RoPE scaling).
    config.pose_prob = args.pose_prob
    config.pose_scaling_factor = args.pose_scaling_factor
    config.rope_scaling_type = args.rope_scaling_type
    config.rope_scaling_factor = args.rope_scaling_factor
    config.orig_max_length = args.orig_max_length

    bmt.print_rank("model config: {}".format(config))
    bmt.print_rank("bmt config: {}".format(bmt.config))

    model = Dragonfly(config)
    if args.load is not None:
        bmt.print_rank("args.load is not None, start to load checkpoint: " + args.load)
        exporter.load_model_ckpt(args, model)
    else:
        bmt.print_rank("args.load is None, start to initialize parameters")
        bmt.init_parameters(model)
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    group = parser.add_argument_group("model", "model configuration")
    group.add_argument("--model-config", type=str, default="apps/fm9g_2b/model_configs/2.4b.json", help="model configuration file")
    group.add_argument("--eps", type=float, default=1e-5, help="eps in layernorm")
    group.add_argument("--load", type=str, default="/root/autodl-tmp/fm9g_2b_hf_models", help="Path to a directory containing a model checkpoint.")
    group.add_argument("--tp-size", default=1, type=int)

    group = parser.add_argument_group("long_context_extend", "long context extension configuration")
    group.add_argument("--pose_prob", default=0.0, type=float, help="Sample-level PoSE probability")
    group.add_argument("--pose_scaling_factor", default=1.0, type=float, help="PoSE scaling factor; simulated input length = max_length * pose_scaling_factor")
    group.add_argument("--rope_scaling_type", default="", type=str, choices=["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""], help="Context scaling type")
    group.add_argument("--rope_scaling_factor", default=1, type=int, help="Context scaling factor")
    group.add_argument("--orig_max_length", default=8192, type=int, help="Original context length before context extension")

    args = parser.parse_args()

    bmt.init_distributed(seed=42, tp_size=1)
    bmt.synchronize()

    # tokenizer
    tokenizer_path = "apps/fm9g_2b/tokenizer/tokenizer.model"
    tokenizer = get_tokenizer(tokenizer_path)

    # model
    model = get_model(args)

    # inference
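    # NOTE: the greedy-decoding loop below is a minimal sketch added for
    # illustration, not part of the original script. It assumes the Dragonfly
    # forward pass can be called as model(input_ids) and returns logits of
    # shape (batch, seq_len, vocab_size); check the actual call signature in
    # fm9g/dragonfly/modeling_dragonfly.py before relying on it.
    prompt = "Please introduce the FM9G model."
    input_ids = torch.tensor([tokenizer.encode(prompt)]).cuda()

    model.eval()
    with torch.no_grad():
        for _ in range(64):  # generate at most 64 new tokens
            logits = model(input_ids)  # assumed call signature, see note above
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            if next_token.item() == tokenizer.eos_token_id:
                break

    bmt.print_rank(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))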