diff --git a/modeling_dragonfly.py b/modeling_dragonfly.py
new file mode 100644
index 0000000..66db290
--- /dev/null
+++ b/modeling_dragonfly.py
@@ -0,0 +1,1170 @@
+# typing: strict
+# coding=utf-8
+# Copyright 2024 QiYuan Inc.
+
+import math
+from typing import Optional
+from typing import Tuple
+
+import bmtrain as bmt
+import torch
+import torch.nn.functional as F
+from flash_attn.flash_attn_interface import flash_attn_varlen_func
+
+from .configuration_dragonfly import DragonflyConfig  # from fm9g.utils import Config
+
+# TODO:
+# 1. add scale_emb to embed and layernorm
+# 2. add scale_width to all layers
+# 3. add scale_depth to residual
+
+
+class ScaledRotaryEmbeddingESM(bmt.DistributedModule):
+    """
+    Rotary position embeddings based on those in
+    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Queries and keys are transformed by
+    rotation matrices which depend on their relative positions.
+    Adds multiple positional-interpolation methods:
+    [Linear](http://arxiv.org/abs/2306.15595)
+    [NTK-aware](https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/)
+    [Dynamic Scaling](https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/)
+    [NTK-by-parts](https://github.com/jquesnelle/yarn/pull/1)
+    [YaRN](http://arxiv.org/abs/2309.00071)
+    Args:
+        dim: Dimension of the input, attn_dim // n_heads.
+        max_position_embeddings: Maximum number of positions to be embedded.
+        base: Base of the positional encoding function.
+        pose_prob: Probability of using PoSE.
+        pose_scaling_factor: max_position_embeddings scaling factor for PoSE.
+        scaling_type: Type of scaling to use, one of ["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""].
+        rope_scaling_factor: RoPE scaling factor for the chosen scaling type: new max length / original max length.
+        beta_fast: Number of rotations to use for fast angular velocity.
+        beta_slow: Number of rotations to use for slow angular velocity.
+        extrapolation_factor: In [0, 1]; 0 falls back to pure interpolation, 1 applies the full NTK-by-parts/YaRN blend.
+        attn_factor: Uniform attention scale factor for tuning YaRN; 1 is best for LLaMA-1/2.
+ """ + + def __init__( + self, + dim: int, + max_position_embeddings: int = 2048, + base: int = 10000, + pose_prob: float = 0.0, + pose_scaling_factor: float = 1.0, + scaling_type: str = "", + rope_scaling_factor: float = 1.0, + beta_fast: float = 32.0, + beta_slow: float = 1.0, + extrapolation_factor: int = 1, + attn_factor: int = 1, + original_max_position_embeddings: int = 2048, + persistent: bool = True, + dynamic_scaling_seq_len: int = 512, + device=None, + ): + assert scaling_type in ["Linear", "NTK-aware", "Dynamic NTK", "NTK-by-parts", "YaRN", ""] + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.persistent = persistent + self.device = device + # scaling config + self.scaling_type = scaling_type + self.pose_scaling_factor = pose_scaling_factor + self.rope_scaling_factor = rope_scaling_factor + # PoSE + self.pose_prob = pose_prob + # NTK-by-parts and YaRN args + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.original_max_position_embeddings = original_max_position_embeddings + + if pose_prob > 0: + self.scaled_max_position_embeddings = int(max_position_embeddings * pose_scaling_factor) + else: + self.scaled_max_position_embeddings = max_position_embeddings + + if self.scaling_type == "NTK-aware": + base = self.base * (self.rope_scaling_factor ** (self.dim / (self.dim - 2))) + else: + base = self.base + # TODO: Implement base NTK-aware in NTK-by-parts + if self.scaling_type in ["NTK-by-parts", "YaRN"]: + self._ntk_parts_update_inv_freq(self.scaled_max_position_embeddings) + else: + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(self.device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # Get n-d magnitude scaling corrected for interpolation + self.m_scale = float(self._get_m_scale(self.rope_scaling_factor) * self.attn_factor) + self._set_cos_sin_cache(dynamic_scaling_seq_len) + + def _get_m_scale(self, scale=1.0): + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + def _ntk_parts_update_inv_freq(self, seq_len): + # Inverse dim formula to find dim based on number of rotations + def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + # Find dim range bounds based on rotations + def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + def linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(self.device) / self.dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (self.rope_scaling_factor * pos_freqs) + low, high = find_correction_range( + self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 - linear_ramp_mask(low, high, self.dim // 
2).float().to(self.device) + ) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=self.persistent) + + def _set_cos_sin_cache(self, seq_len, device=None): + self.max_seq_len_cached = seq_len + if device is not None: + self.device = device + + if self.scaling_type == "Dynamic NTK" and seq_len > self.max_position_embeddings: + base = self.base * ( + (self.rope_scaling_factor * seq_len / self.max_position_embeddings) - (self.rope_scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(self.device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=self.persistent) + + t = torch.arange(self.max_seq_len_cached, device=self.device).type_as(self.inv_freq) + if self.scaling_type == "Linear": + freqs = torch.outer(t / self.rope_scaling_factor, self.inv_freq.to(device=t.device).to(t.dtype)) + else: + freqs = torch.outer(t, self.inv_freq.to(device=t.device).to(t.dtype)) + emb = torch.cat((freqs, freqs), dim=-1) + if self.scaling_type == "YaRN": + self.register_buffer("cos_cached", (emb.cos() * self.m_scale), persistent=self.persistent) + self.register_buffer("sin_cached", (emb.sin() * self.m_scale), persistent=self.persistent) + else: + self.register_buffer("cos_cached", emb.cos(), persistent=self.persistent) + self.register_buffer("sin_cached", emb.sin(), persistent=self.persistent) + + def _rotate_half(self, x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_rotary_pos_emb(self, q, k, cos, sin, position_ids) -> Tuple[torch.Tensor, torch.Tensor]: + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
+ orig_dtype = k.dtype + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_fp32 = q.to(dtype=torch.float32, device=q.device) + k_fp32 = k.to(dtype=torch.float32, device=k.device) + q_embed = (q_fp32 * cos) + (self._rotate_half(q_fp32) * sin) + k_embed = (k_fp32 * cos) + (self._rotate_half(k_fp32) * sin) + return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype) + + def forward( + self, q: torch.Tensor, k: torch.Tensor, seq_dim, offset=0, cu_seqlens=None, max_length=None, position_ids=None + ) -> Tuple[torch.Tensor, torch.Tensor]: + seq_dim = (seq_dim + k.dim()) % k.dim() + # get max current seq len from all workers + if self.pose_prob > 0.0: + seq_len = torch.max(position_ids) + 1 + else: + seq_len = k.size(seq_dim) + offset + seq_len_tensor = torch.tensor(seq_len, device=self.device) + seq_len_tensor_reduced = bmt.distributed.all_reduce(seq_len_tensor, op="max") + seq_len_reduced = seq_len_tensor_reduced.item() + # update cache if needed + if seq_len_reduced > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len) + + cos, sin = ( + self.cos_cached[:seq_len_reduced], + self.sin_cached[:seq_len_reduced], + ) + if position_ids.dtype != torch.long: # 231108 input is int32 + position_ids = position_ids.to(dtype=torch.long) + if cu_seqlens is None: + q_embed, k_embed = self._apply_rotary_pos_emb(q, k, cos, sin, position_ids) + else: + assert offset == 0, "past kv is not supported in flash attn" + q_embed, k_embed = self._apply_rotary_pos_emb(q, k, cos, sin, position_ids.view(-1)) + + return q_embed, k_embed + + +def Linear(*args, **kwargs): + tp = kwargs.pop("tp", 0) + if tp == 0: + return NormalLinear(*args, **kwargs) + if tp == 1: + return ColumnParallelLinear(*args, **kwargs) + if tp == 2: + return RowParallelLinear(*args, **kwargs) + + +class NormalLinear(bmt.DistributedModule): + def __init__( + self, + dim_in: int, + dim_out: int, + dtype: torch.dtype = torch.bfloat16, + init_mean: float = 0.0, + init_std: float = 0.02, + ): + super().__init__() + self.dim_in = self.in_features = dim_in + self.dim_out = self.out_features = dim_out + + # TODO:init + # init_std = 1 / ((dim_in + dim_out) ** 0.5) + + self.weight = bmt.DistributedParameter( + torch.empty((dim_out, dim_in), dtype=dtype), + init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std), + ) + + def forward(self, x: torch.Tensor): + """ + Args: + x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer + Returns: + :obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y. 
+ """ # noqa: E501 + + x = F.linear(x, self.weight, None) + + return x + + +class ColumnParallelLinear(bmt.DistributedModule): + def __init__( + self, + dim_in: int, + dim_out: int, + dtype: torch.dtype = torch.bfloat16, + init_mean: float = 0.0, + init_std: float = 0.02, + gather_output=False, + gather_input=True, + ): + super().__init__() + assert dim_out % bmt.config["tp_size"] == 0 + + # TODO: init + # init_std = 1 / ((dim_in + dim_out) ** 0.5) + + dim_out = dim_out // bmt.config["tp_size"] + self.dim_in = self.in_features = dim_in + self.dim_out = self.out_features = dim_out + self.gather_input = gather_input + self.gather_output = gather_output + + self.weight = bmt.DistributedParameter( + torch.empty((dim_out, dim_in), dtype=dtype), + init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std), + tp_split_dim=0, + tp_mode=True, + ) + self.bias = None + + def forward(self, x: torch.Tensor): + """ + Args: + x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer + Returns: + :obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y. + """ # noqa: E501 + + x = bmt.nn.OpParallelLinear.apply( + x, self.weight, self.bias, self.gather_input, self.gather_output, False, None, 1 + ) + + return x + + +class RowParallelLinear(bmt.DistributedModule): + def __init__( + self, + dim_in: int, + dim_out: int, + dtype: torch.dtype = torch.bfloat16, + init_mean: float = 0.0, + init_std: float = 0.02, + split_input=False, + all_reduce_output=False, + ): + super().__init__() + assert dim_in % bmt.config["tp_size"] == 0 + # init_std = 1 / ((dim_in + dim_out) ** 0.5) + + dim_in = dim_in // bmt.config["tp_size"] + self.dim_in = self.in_features = dim_in + self.dim_out = self.out_features = dim_out + + self.split_input = split_input + self.all_reduce_output = all_reduce_output + + self.weight = bmt.DistributedParameter( + torch.empty((dim_out, dim_in), dtype=dtype), + init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std), + tp_split_dim=1, + tp_mode=True, + ) + + def forward(self, x: torch.Tensor): + """ + Args: + x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of linear layer + Returns: + :obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of the linear transform y. + """ # noqa: E501 + if not self.all_reduce_output: + x = x.view(x.shape[0] * bmt.config["tp_size"], -1, x.shape[-1]) + + x = bmt.nn.OpParallelLinear.apply( + x, self.weight, None, self.split_input, False, self.split_input, 1 if self.all_reduce_output else 2, 1 + ) + + return x + + +@torch.jit.script +def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float): + old_dtype = hidden.dtype + variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True) + hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype) + return hidden * weight + + +class LayerNorm(bmt.DistributedModule): + """RMS LayerNorm""" + + def __init__( + self, + dim_norm: int, + dtype: torch.dtype = torch.bfloat16, + eps: float = 1e-6, + init_var: float = 1.0, + ): + super().__init__() + + self.eps = eps + self.dim_norm = dim_norm + self.weight = bmt.DistributedParameter(torch.full((dim_norm,), init_var, dtype=dtype)) + + def forward(self, x: torch.Tensor): + """ + Args: + x (:obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``): Input tensor that need to be normalized. 
+ Return: + :obj:`torch.Tensor` of shape ``(batch_size, seq_len, dim_norm)``: The layernorm output. + """ # noqa: E501 + assert x.size(-1) == self.dim_norm + return rms_layernorm(x, self.weight, self.eps) + + +class DenseGatedACT(bmt.DistributedModule): + def __init__( + self, + dim_in: int, + dim_ff: int, + activate_fn: str = "silu", + dtype=torch.bfloat16, + tp: int = 0, + scale: bool = False, + init_std: float = 0.02, + scale_width: float = 1.0, + ): + super().__init__() + + _std = init_std / math.sqrt(scale_width) if scale else init_std + + self.w_0 = Linear( + dim_in=dim_in, + dim_out=dim_ff, + dtype=dtype, + tp=tp, + init_std=_std, + ) + + self.w_1 = Linear(dim_in=dim_in, dim_out=dim_ff, dtype=dtype, tp=tp, init_std=_std) + + if activate_fn == "gelu": + self.act = torch.nn.GELU() + elif activate_fn == "silu": + self.act = torch.nn.functional.silu + else: + raise NotImplementedError(f"{activate_fn} is not supported") + + def forward(self, x: torch.Tensor): + """This model inherits from bmt.DistributedModule. + Transform an input tensor from one feature space to another via a nonlinear operation + + Args: + x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): Tensor that will be subject to nonlinear operations. + + Return: + out (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_ff)``) + + """ # noqa: E501 + gate_score = self.act(self.w_0(x)) + x = self.w_1(x) + + x = gate_score * x + return x + + +class FeedForward(bmt.DistributedModule): + r"""FeedForward module + + Args: + dim_in (int): input dimension. + dim_ff (int): middle dimension. + dim_out (int, optional): output dimension. Defaults to None, which means dim_in = dim_out. + dtype (optional): Defaults to torch.bfloat16. + init_mean (float, optional): mean of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)` for fully-connected module used in feed-forward layer. Defaults to 0. + init_std (float, optional): std of :math:`\mathbf{W}\sim\mathcal{N}(\text{mean}, \text{std}^2)` for fully-connected module used in feed-forward layer. Defaults to 0.02. + bias (bool, optional): whether to use bias term in fully-connected layers used in feed-forward module. Defaults to False. + activate_fn (str, optional): Defaults to `gated_gelu`. + dropout_p (int, optional): Defaults to 0. + """ # noqa: E501 + + def __init__( + self, + dim_model: int, + dim_ff: int, + activate_fn: str = "silu", + dtype=torch.bfloat16, + dropout_p: Optional[float] = None, + tp: int = 0, + scale: bool = False, + init_std: float = 0.02, + scale_width: float = 1.0, + ): + super().__init__() + + self.w_in = DenseGatedACT( + dim_in=dim_model, + dim_ff=dim_ff, + activate_fn=activate_fn, + dtype=dtype, + scale=scale, + init_std=init_std, + scale_width=scale_width, + ) + + if dropout_p is not None: + self.dropout = torch.nn.Dropout(dropout_p) + else: + self.dropout = None + + _std = init_std / math.sqrt(scale_width) if scale else init_std + self.w_out = Linear(dim_in=dim_ff, dim_out=dim_model, dtype=dtype, init_std=_std) + + def forward(self, x: torch.Tensor): + """ + Args: + x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_in)``): The input of feed-forward module. + + Return: + :obj:`torch.Tensor` of shape ``(batch, seq_len, dim_out)``: The output of feed-forward module. 
+ """ # noqa: E501 + x = self.w_in(x) + + if self.dropout is not None: + x = self.dropout(x) + + x = self.w_out(x) + + return x + + +class Embedding(bmt.DistributedModule): + def __init__( + self, + vocab_size: int, + embedding_size: int, + dtype: torch.dtype = torch.bfloat16, + init_mean: float = 0.0, + init_std: float = 1, + scale: bool = False, + scale_emb: float = 1.0, + scale_width: float = 1.0, + tp: int = 0, + ): + super().__init__() + + self.dim_model = embedding_size + self.weight = bmt.DistributedParameter( + torch.empty(vocab_size, embedding_size, dtype=dtype), + init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std), + ) + self.tp = tp + self.scale = scale + self.scale_emb = scale_emb + self.scale_width = scale_width + + def forward(self, x: torch.Tensor): + """ + Args: + x (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens. + Return: + :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output. + """ # noqa: E501 + + if self.tp: + x = x.view(-1).chunk(bmt.config["tp_size"])[bmt.config["tp_rank"]].view(x.size(0), -1) + + embeds = F.embedding(x, self.weight) + + if self.scale: + embeds = embeds * self.scale_emb + + return embeds + + def projection(self, x: torch.Tensor): + """ + Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size. + Args: + x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection + Returns: + :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output. + """ # noqa: E501 + + if self.scale: + x = x / self.scale_width # TODO: check if it is ok to add before all_gather + + logits = F.linear(x, self.weight) + return logits + + +class Attention(bmt.DistributedModule): + def __init__( + self, + dim_model: int, + num_heads: int, + num_kv_heads: int, + dim_head: int, + dtype: torch.dtype = torch.bfloat16, + dropout_p: Optional[float] = None, + tp: int = 0, + scale: bool = False, + init_std: float = 0.02, + scale_width: float = 1.0, + qk_norm: bool = False, + ) -> None: + super().__init__() + + self.dim_model = dim_model + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_groups = num_heads // num_kv_heads + self.dim_head = dim_head + + self.scale = scale + _std = init_std / math.sqrt(scale_width) if scale else init_std + + self.project_q = Linear( + self.dim_model, + self.num_heads * self.dim_head, + dtype=dtype, + tp=tp, + init_std=_std, + ) + self.project_k = Linear( + self.dim_model, + self.num_kv_heads * self.dim_head, + dtype=dtype, + tp=tp, + init_std=_std, + ) + self.project_v = Linear( + self.dim_model, + self.num_kv_heads * self.dim_head, + dtype=dtype, + tp=tp, + init_std=_std, + ) + + self.attention_out = Linear( + self.num_heads * self.dim_head, + self.dim_model, + dtype=dtype, + tp=tp * 2, # TODO + init_std=_std, + ) + + if dropout_p is not None: + self.dropout = torch.nn.Dropout(p=dropout_p) + self.dropout_p = dropout_p + else: + self.dropout = None + + self.tp = tp + + def forward( + self, + hidden_q: torch.Tensor, + hidden_kv: torch.Tensor, + position_bias: torch.Tensor, # TODO + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: int = None, + position_ids: Optional[torch.Tensor] = None, + ): + """This model inherits from bmt.DistributedModule. + Args: + hidden_q (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): Indices of input sequence tokens. 
It will be embedded by model's internal embedding lookup matrix. + hidden_kv (:obj:`torch.Tensor` of shape ``(batch, len_k, dim_model)``): Length of input sequence before padding. + attention_mask (:obj:`torch.Tensor` of shape ``(batch, len_q, len_k)``): Used to avoid performing attention on padding token indices. + position_bias(:obj:`torch.Tensor` of shape ``(num_heads, len_q, len_k)`` or ``(1, num_heads, len_k, len_q)``): Provide positional information about tensor `key_value` and `query`. + Return: + out (:obj:`torch.Tensor` of shape ``(batch, len_q, dim_model)``): The attention output. + """ # noqa: E501 + + batch_size = hidden_q.size(0) + + if self.tp: + assert hidden_q.data_ptr() == hidden_kv.data_ptr() + + hidden_q = bmt.nn.OpParallelLinear.apply( + hidden_q, + torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0), + torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0) + if self.project_q.bias is not None + else None, + True, + False, + False, + None, + 1, + ) + + hidden_q = hidden_q.view(batch_size, -1, hidden_q.shape[-1]) + + block_size = hidden_q.shape[-1] // (self.head_groups + 1 + 1) + h_q = hidden_q[..., : block_size * self.head_groups] + h_k = hidden_q[..., block_size * self.head_groups : block_size * (self.head_groups + 1)] + h_v = hidden_q[..., block_size * (self.head_groups + 1) :] + else: + h_q = self.project_q(hidden_q) + h_k = self.project_k(hidden_kv) + h_v = self.project_v(hidden_kv) + + len_q = h_q.size(1) + len_k = h_k.size(1) + + h_q = h_q.view(batch_size * len_q, -1, self.dim_head) + h_k = h_k.view(batch_size * len_k, -1, self.dim_head) + h_v = h_v.view(batch_size * len_k, -1, self.dim_head) + h_q, h_k = position_bias(h_q, h_k, -3, cu_seqlens=cu_seqlens, max_length=max_seqlen, position_ids=position_ids) + score = flash_attn_varlen_func( + h_q, + h_k, + h_v, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + self.dropout_p, + causal=True, + deterministic=True, + ) + + #print("DEBUG! use flash!!!!!! ARQ") + score = score.view(batch_size, len_q, -1) + score = self.attention_out(score) + + return score + + +class SelfAttentionBlock(bmt.DistributedModule): + """The whole cross-attention block. A sequence of operation. Consists of layernorm, self-attention and residual connection. + + Args: + dim_model (int): main dimension of modules in transformer blocks. + num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`. + dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`. + dtype (optional): Defaults to torch.bfloat16. + eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5. + dropout_p (float, optional): Defaults to 0. 
+ """ # noqa: E501 + + def __init__( + self, + dim_model: int, + num_heads: int, + num_kv_heads: int, + dim_head: int, + dtype=torch.bfloat16, + eps: float = 1e-6, + dropout_p: Optional[float] = None, + tp: int = 0, + scale: bool = False, + init_std: float = 0.02, + scale_width: float = 1.0, + scale_depth: float = -1, + qk_norm: bool = False, + layer_id: int = 0, + num_layers: int = 0, + ): + super().__init__() + + self.layernorm_before_attention = LayerNorm( + dim_model, + dtype=dtype, + eps=eps, + ) + + self.self_attention = Attention( + dim_model=dim_model, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + dim_head=dim_head, + dtype=dtype, + dropout_p=dropout_p, + tp=tp, + scale=scale, + init_std=init_std, + scale_width=scale_width, + qk_norm=qk_norm, + ) + + if dropout_p: + self.dropout = torch.nn.Dropout(dropout_p) + else: + self.dropout = None + + self.scale = scale + self.scale_depth = scale_depth + self.num_layers = num_layers + + def forward( + self, + hidden_states: torch.Tensor, + position_bias: ScaledRotaryEmbeddingESM, + cu_seqlens: torch.Tensor, + max_seqlen: int = None, + position_ids: Optional[torch.Tensor] = None, + ): + """ + Args: + hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of self-attention block. It can be the embedding of a batch of sequences. + attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation. + position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block. + + Return: + :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of attention block. + + """ # noqa: E501 + x = self.layernorm_before_attention(hidden_states) + x = self.self_attention( + x, + x, + position_bias, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + position_ids=position_ids, + ) + + if self.dropout is not None: + x = self.dropout(x) + + if self.scale_depth > 0: + hidden_states = hidden_states + x * ( + self.scale_depth / math.sqrt(self.num_layers) + ) # https://arxiv.org/pdf/2310.02244.pdf + else: + hidden_states = hidden_states + x + + return hidden_states + + +class FFNBlock(torch.nn.Module): + """The whole feed-forward block. A sequence of operation. Consists of layernorm, feed-forward and residual connection. + + Args: + dim_model (int): main dimension of modules in transformer blocks. + dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`. + dtype (optional): Defaults to torch.bfloat16. + eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5. + dropout_p (float, optional): Defaults to 0. 
+ """ # noqa: E501 + + def __init__( + self, + dim_model: int, + dim_ff: int, + activate_fn: str, + dtype=torch.bfloat16, + eps: float = 1e-6, + dropout_p: Optional[float] = 0, + tp: int = 0, + scale: bool = False, + init_std: float = 0.02, + scale_width: float = 1.0, + scale_depth: float = -1, + layer_id: int = 0, + num_layers: int = 0, + ): + super().__init__() + + self.layernorm_before_ffn = LayerNorm( + dim_model, + dtype=dtype, + eps=eps, + ) + + self.ffn = FeedForward( + dim_model, + dim_ff, + activate_fn=activate_fn, + dtype=dtype, + dropout_p=dropout_p, + tp=tp, + scale=scale, + init_std=init_std, + scale_width=scale_width, + ) + + if dropout_p: + self.dropout = torch.nn.Dropout(dropout_p) + else: + self.dropout = None + + self.scale = scale + self.scale_depth = scale_depth + self.num_layers = num_layers + + def forward( + self, + hidden_states: torch.Tensor, + ): + """ + Args: + hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Hidden states before feed forward layer. + + Return: + :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of feed-forward block + + """ # noqa: E501 + x = self.layernorm_before_ffn(hidden_states) + x = self.ffn(x) + if self.dropout is not None: + x = self.dropout(x) + + if self.scale_depth > 0: + hidden_states = hidden_states + x.view_as(hidden_states) * ( + self.scale_depth / math.sqrt(self.num_layers) + ) # https://arxiv.org/pdf/2310.02244.pdf + else: + hidden_states = hidden_states + x.view_as(hidden_states) + + return hidden_states + + +class TransformerBlock(torch.nn.Module): + """The whole transformer block. A sequence of operation. Consists of self-attention block[, cross-attention block] and feed-forward block. + + Args: + dim_model (int): main dimension of modules in transformer blocks. + dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`. + num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`. + dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`. + dtype (optional): Defaults to torch.bfloat16. + eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5. + dropout_p (float, optional): Defaults to 0. 
+ """ # noqa: E501 + + def __init__( + self, + dim_model: int, + dim_ff: int, + num_heads: int, + num_kv_heads: int, + dim_head: int, + activate_fn: str = "silu", + dtype=torch.bfloat16, + eps: float = 1e-6, + dropout_p: Optional[float] = None, + tp: int = 0, + scale: bool = False, + init_std: float = 0.02, + scale_width: float = 1.0, + scale_depth: float = -1, + qk_norm: bool = False, + layer_id: int = 0, + num_layers: int = 0, + ): + super().__init__() + + self.self_att = SelfAttentionBlock( + dim_model=dim_model, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + dim_head=dim_head, + dtype=dtype, + eps=eps, + dropout_p=dropout_p, + tp=tp, + scale=scale, + init_std=init_std, + scale_width=scale_width, + scale_depth=scale_depth, + qk_norm=qk_norm, + layer_id=layer_id, + num_layers=num_layers, + ) + + self.ffn = FFNBlock( + dim_model=dim_model, + dim_ff=dim_ff, + activate_fn=activate_fn, + dtype=dtype, + eps=eps, + dropout_p=dropout_p, + tp=tp, + scale=scale, + init_std=init_std, + scale_width=scale_width, + scale_depth=scale_depth, + layer_id=layer_id, + num_layers=num_layers, + ) + + def forward( + self, + self_hidden_states: torch.Tensor, + self_position_bias: Optional[torch.Tensor] = None, # TODO + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + ): + """ + Args: + self_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of transformer block(self-attention block). It can be the raw embedding of a batch of sequences. + self_attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Avoid invalid areas to participate in the calculation of self-attention. + self_position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provide positional information to self-attention block. + + Return: + :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of transformer block. + + """ # noqa: E501 + # (batch, dim_model, seq_self) + hidden_states = self.self_att( + self_hidden_states, + position_bias=self_position_bias, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + position_ids=position_ids, + ) + + # (batch, dim_model, seq_self) + hidden_states = self.ffn(hidden_states) + + return hidden_states + + +class Encoder(bmt.DistributedModule): + """Layers of encoder transformer blocks plus an final layernorm. + + Args: + num_layers (int): number of layers. + dim_model (int): main dimension of modules in transformer blocks. + dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`. + num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`. + dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`. + dtype (optional): Defaults to torch.bfloat16. + eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5. + dropout_p (float, optional): Defaults to 0. 
+ """ # noqa: E501 + + def __init__( + self, + num_layers: int, + dim_model: int, + dim_ff: int, + num_heads: int, + dim_head: int, + num_kv_heads: int = -1, + activate_fn: str = "silu", + dtype: torch.dtype = torch.bfloat16, + eps: float = 1e-6, + dropout_p: Optional[float] = None, + tp: int = 0, + scale: bool = False, + init_std: float = 0.02, + scale_width: float = 1.0, + scale_depth: float = -1, + qk_norm: bool = False, + use_checkpoint: bool = True, + ): + super().__init__() + if num_kv_heads == -1: + num_kv_heads = num_heads + self.num_layers = num_layers + + self.layers = bmt.TransformerBlockList( + [ + bmt.CheckpointBlock( + TransformerBlock( + dim_model=dim_model, + dim_ff=dim_ff, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + dim_head=dim_head, + activate_fn=activate_fn, + dtype=dtype, + eps=eps, + dropout_p=dropout_p, + tp=tp, + scale=scale, + init_std=init_std, + scale_width=scale_width, + scale_depth=scale_depth, + qk_norm=qk_norm, + layer_id=layer_id, + num_layers=num_layers, + ), + use_checkpoint=use_checkpoint + ) + for layer_id in range(num_layers) + ] + ) + self.output_layernorm = LayerNorm(dim_norm=dim_model, dtype=dtype, eps=eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_bias: torch.Tensor = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + ): + """ + Args: + hidden-states (:obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``): Input of encoder, might be the embedding of a batch of sequences. + attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_enc, seq_enc)``): Avoid invalid areas to participate in the calculation + position_bias(:obj:`torch.Tensor` of shape ``(num_heads, seq_enc, seq_enc)``) Provides position information to attention mechanism. + + Return: + :obj:`torch.Tensor` of shape ``(batch, seq_enc, dim_model)``: The encoder output. 
+ + """ # noqa: E501 + hidden_states = self.layers( + hidden_states, + position_bias, + cu_seqlens, + max_seqlen, + position_ids, + ) + hidden_states = self.output_layernorm(hidden_states) + return hidden_states + + +class Dragonfly(bmt.DistributedModule): + def __init__(self, config: DragonflyConfig): + super().__init__() + + self.encoder = Encoder( + num_layers=config.num_layers, + dim_model=config.dim_model, + dim_ff=config.dim_ff, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + dim_head=config.dim_head, + activate_fn=config.activate_fn, + dtype=config.dtype, + eps=config.eps, + dropout_p=config.dropout_p, + tp=config.tp, + scale=config.scale, + init_std=config.init_std, + scale_width=config.scale_width, + scale_depth=config.scale_depth, + qk_norm=config.qk_norm, + use_checkpoint=config.use_checkpoint, + ) + + self.input_embedding = Embedding( + vocab_size=config.vocab_size, + embedding_size=config.dim_model, + dtype=config.dtype, + init_std=config.init_std, + tp=config.tp, + scale=config.scale, + scale_emb=config.scale_emb, + scale_width=config.scale_width, + ) + + self.position_bias = ScaledRotaryEmbeddingESM( + dim=config.dim_head, + max_position_embeddings=config.max_length, + base=config.base, + pose_prob=config.pose_prob, + pose_scaling_factor=config.pose_scaling_factor, + scaling_type=config.rope_scaling_type, + rope_scaling_factor=config.rope_scaling_factor, + original_max_position_embeddings=config.orig_max_length, + dynamic_scaling_seq_len=config.max_length, # disable dynamic scaling + persistent=False, + device="cuda", + ) + + if config.tie_lm_head is False: + self.lm_head = Embedding( + vocab_size=config.vocab_size, + embedding_size=config.dim_model, + dtype=config.dtype, + init_std=config.init_std, + scale=config.scale, + scale_width=config.scale_width, + tp=config.tp, + ) + + self.config = config + + def forward( + self, + input: torch.Tensor, # (batch, seqlen) int32 + cu_seqlens: torch.Tensor = None, # (real_batch+2) int32 + max_seqlen: int = None, + position_ids: torch.Tensor = None, # (batch, seqlen) int32 + ): + hidden_states = self.input_embedding(input) + + assert cu_seqlens is not None, "cu_seqlens are needed in Flash Attention cuda impl" + hidden_states = self.encoder( + hidden_states, + position_bias=self.position_bias, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + position_ids=position_ids, + ) + + if self.config.tie_lm_head is True: + logits = self.input_embedding.projection(hidden_states) + else: + logits = self.lm_head.projection(hidden_states) + + return logits
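A minimal sketch of the inverse-frequency schedules ScaledRotaryEmbeddingESM selects between (plain RoPE, "NTK-aware" base stretching, and the NTK-by-parts/YaRN ramp), written with only torch and math and assuming extrapolation_factor = 1. The helper names below are illustrative, not taken from the patch.

import math
import torch


def rope_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # standard RoPE: theta_i = base ** (-2i / dim), one frequency per pair of channels
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))


def ntk_aware_inv_freq(dim: int, scaling_factor: float, base: float = 10000.0) -> torch.Tensor:
    # "NTK-aware": stretch the base so fast (high-frequency) channels keep extrapolating
    # while slow channels are effectively interpolated
    return rope_inv_freq(dim, base * scaling_factor ** (dim / (dim - 2)))


def yarn_inv_freq(dim: int, scaling_factor: float, base: float = 10000.0,
                  beta_fast: float = 32.0, beta_slow: float = 1.0,
                  original_max_pos: int = 2048) -> torch.Tensor:
    # NTK-by-parts / YaRN (extrapolation_factor = 1): blend interpolated and extrapolated
    # frequencies with a linear ramp over the channels whose rotation count at
    # original_max_pos lies between beta_slow and beta_fast
    def correction_dim(num_rotations: float) -> float:
        return dim * math.log(original_max_pos / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), dim - 1)
    ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / max(high - low, 1e-3), 0, 1)
    extrapolation = rope_inv_freq(dim, base)
    interpolation = extrapolation / scaling_factor
    # ramp == 0 -> pure extrapolation (fast channels), ramp == 1 -> pure interpolation (slow channels)
    return extrapolation * (1 - ramp) + interpolation * ramp


# YaRN additionally rescales the cos/sin caches by m_scale = 0.1 * ln(s) + 1 (times attn_factor)
m_scale = 0.1 * math.log(4.0) + 1.0  # e.g. for a 4x context extension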
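A quick numeric check (not from the repo) that the rotate_half formulation used in _apply_rotary_pos_emb makes the query-key score depend only on the relative distance between positions:

import torch

dim = 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(128).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotate(x: torch.Tensor, pos: int) -> torch.Tensor:
    # rotates each (x_i, x_{i + dim/2}) pair by pos * theta_i
    return x * cos[pos] + rotate_half(x) * sin[pos]


q, k = torch.randn(dim), torch.randn(dim)
score_a = rotate(q, 10) @ rotate(k, 4)    # positions (10, 4), offset 6
score_b = rotate(q, 26) @ rotate(k, 20)   # positions (26, 20), same offset 6
assert torch.allclose(score_a, score_b, atol=1e-4)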
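A single-process sketch of why ColumnParallelLinear splits its weight on dim 0 (output features, tp_split_dim=0) and RowParallelLinear on dim 1 (input features, tp_split_dim=1): a column-split matmul followed by a row-split matmul reproduces the unsharded result once the partial outputs are summed, which corresponds to the reduction bmtrain's OpParallelLinear performs. Pure torch, no bmtrain collectives; shapes are placeholders.

import torch

tp_size, d_in, d_ff, batch = 4, 16, 32, 3
x = torch.randn(batch, d_in)
w_up = torch.randn(d_ff, d_in)    # column-parallel weight: split rows (output features)
w_down = torch.randn(d_in, d_ff)  # row-parallel weight: split columns (input features)

# column parallel: each "rank" produces its own slice of the hidden features
up_shards = w_up.chunk(tp_size, dim=0)
hidden_shards = [x @ w.t() for w in up_shards]          # each (batch, d_ff // tp_size)

# row parallel: each rank consumes its slice and produces a partial sum of the output
down_shards = w_down.chunk(tp_size, dim=1)
partials = [h @ w.t() for h, w in zip(hidden_shards, down_shards)]
y_tp = torch.stack(partials).sum(0)                     # the all-reduce step

y_ref = (x @ w_up.t()) @ w_down.t()
assert torch.allclose(y_tp, y_ref, atol=1e-3)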
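The gated-SiLU feed-forward computed by DenseGatedACT plus FeedForward, rewritten with plain nn.Linear for readability (a sketch, not the distributed implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    def __init__(self, dim_model: int, dim_ff: int):
        super().__init__()
        self.w_gate = nn.Linear(dim_model, dim_ff, bias=False)  # w_0 in the patch
        self.w_up = nn.Linear(dim_model, dim_ff, bias=False)    # w_1 in the patch
        self.w_out = nn.Linear(dim_ff, dim_model, bias=False)   # w_out in the patch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # gate with SiLU, multiply elementwise, then project back to dim_model
        return self.w_out(F.silu(self.w_gate(x)) * self.w_up(x))


y = SwiGLUFFN(dim_model=64, dim_ff=160)(torch.randn(2, 8, 64))  # -> (2, 8, 64)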
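A sketch collecting in one place where the scale_* knobs act on the forward pass when config.scale is enabled; the numeric values below are placeholders, not the model's defaults.

import math
import torch

scale_emb, scale_width, scale_depth, num_layers = 12.0, 8.0, 1.4, 32

token_embeds = torch.randn(2, 16, 64)
h = token_embeds * scale_emb                                     # Embedding.forward: amplify token embeddings

branch = torch.randn_like(h)                                     # output of an attention or FFN sub-block
h = h + branch * (scale_depth / math.sqrt(num_layers))           # SelfAttentionBlock / FFNBlock residual scaling

emb_weight = torch.randn(1000, 64)
logits = (h / scale_width) @ emb_weight.t()                      # Embedding.projection: shrink before the tied head

init_std = 0.02 / math.sqrt(scale_width)                         # per-layer init std used by the Linear modules when scale=True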
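Grouped-query attention as used in the Attention module (head_groups = num_heads // num_kv_heads): each key/value head is shared by a group of query heads. The sketch below expands the kv heads and uses torch's scaled_dot_product_attention for clarity; flash_attn_varlen_func, as called in the patch, accepts the un-expanded kv heads directly.

import torch
import torch.nn.functional as F

batch, seq, num_heads, num_kv_heads, dim_head = 2, 16, 8, 2, 32
q = torch.randn(batch, num_heads, seq, dim_head)
k = torch.randn(batch, num_kv_heads, seq, dim_head)
v = torch.randn(batch, num_kv_heads, seq, dim_head)

# expand each kv head across its group of query heads, then run ordinary causal attention
groups = num_heads // num_kv_heads
k = k.repeat_interleave(groups, dim=1)
v = v.repeat_interleave(groups, dim=1)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # (batch, num_heads, seq, dim_head)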
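How the cu_seqlens / max_seqlen / position_ids arguments consumed by flash_attn_varlen_func and the rotary module are typically built when several documents are packed into one batch row. This is a sketch of the standard flash-attn convention; the repo's dataloader may construct them differently.

import torch

doc_lengths = [5, 3, 8]                                   # three documents packed into one row
seqlens = torch.tensor(doc_lengths, dtype=torch.int32)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), seqlens.cumsum(dim=0, dtype=torch.int32)])
# cu_seqlens == tensor([0, 5, 8, 16]): document boundaries in the flattened token stream;
# with causal=True, tokens attend only within their own document
max_seqlen = int(seqlens.max())                           # 8

# position ids restart at 0 for every document, so RoPE sees per-document positions
position_ids = torch.cat([torch.arange(n, dtype=torch.int32) for n in doc_lengths])
# tensor([0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7])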
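Finally, a hedged sketch of the calling convention Dragonfly.forward expects for one packed row. The DragonflyConfig construction, the bmtrain initialization calls, and the in-package import path are assumptions based on this file alone; the real training entry point lives elsewhere in the repo, and its dataloader may size cu_seqlens differently (the inline comment above mentions a (real_batch+2) length).

import torch
import bmtrain as bmt

from .configuration_dragonfly import DragonflyConfig
from .modeling_dragonfly import Dragonfly

bmt.init_distributed()                      # bmtrain must be initialized before building DistributedModules
config = DragonflyConfig()                  # assumed: default construction; real configs are loaded elsewhere
model = Dragonfly(config)
bmt.init_parameters(model)

seq_len = 16
input_ids = torch.randint(0, config.vocab_size, (1, seq_len), dtype=torch.int32, device="cuda")
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")   # a single document in the row
position_ids = torch.arange(seq_len, dtype=torch.int32, device="cuda").unsqueeze(0)

logits = model(input_ids, cu_seqlens=cu_seqlens, max_seqlen=seq_len, position_ids=position_ids)
# logits: (1, seq_len, vocab_size)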