diff --git a/training_stats.py b/training_stats.py new file mode 100644 index 0000000..785fe99 --- /dev/null +++ b/training_stats.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright @2023 AI, ZHIHU Inc. (zhihu.com) +# +# @author: hsd9026 +# @date: 2023/07/07 +# + + +import copy + +from .log import logger + + +def num_parameters(model): + """Return the number of parameters of a model""" + total_params = 0 + for param_name, param in model.state_dict().items(): + # print(param_name, param.numel()) + total_params += param.numel() + return total_params + + +def num_non_embedding_parameters(model): + """Return the number of parameters of a model""" + total_params = 0 + for param_name, param in model.state_dict().items(): + if ("embed" in param_name) or ("lm_head" in param_name): + continue + # print(param_name, param.numel()) + total_params += param.numel() + return total_params + + +def estimate_parameters(config): + """Estimate the number of parameters of a model given its config, should be equal to `num_parameters(model)`""" + # embedding parameters + embedding_params = config.vocab_size * config.dim_model + self_attn_params = 4 * config.dim_model * config.dim_head * config.num_heads * config.num_layers + ff_params = 3 * config.dim_model * config.dim_ff * config.num_layers + layernorm_parameters = 2 * config.dim_model * config.num_layers + output_layernorm_parameters = config.dim_model + positional_bias = 32 + total_params = ( + embedding_params + + self_attn_params + + ff_params + + layernorm_parameters + + output_layernorm_parameters + + positional_bias + ) + + params_without_embeddings = total_params - embedding_params + + return total_params, params_without_embeddings + + +def get_flops_per_token(config): + """An estimated version of pfdays per token, i.e., the 6N in equation: Computation = 6 N * D.""" + + _, N = estimate_parameters(config) + + logger.info(">>>>>> pfdays_per_token >>>> {:,.0f}".format(N)) + + # evaluating a forward pass + C_forward = 6 * N # + 2 * n_layer * n_ctx * d_model + # unit = 10**15 * 3600 * 24 + + # C_forward = C_forward / unit + + return C_forward