From 12e0270188d0a2a34d93809b253d6b00fba32092 Mon Sep 17 00:00:00 2001
From: paxflsu4r <198028451@qq.com>
Date: Mon, 20 Jan 2025 16:56:02 +0800
Subject: [PATCH] ADD file via upload

---
 blocks.py | 300 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 300 insertions(+)
 create mode 100644 blocks.py

diff --git a/blocks.py b/blocks.py
new file mode 100644
index 0000000..b7412f3
--- /dev/null
+++ b/blocks.py
@@ -0,0 +1,300 @@
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import bmtrain as bmt
+import torch
+
+from .attention import Attention
+from .feedforward import FeedForward
+from .layernorm import LayerNorm
+from .position_embedding import RotaryEmbedding
+from .position_embedding import RotaryEmbeddingESM
+
+
+class SelfAttentionBlock(bmt.DistributedModule):
+    """The whole self-attention block. A sequence of operations: layernorm, self-attention and a residual connection.
+
+    Args:
+        dim_model (int): main dimension of modules in transformer blocks.
+        num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
+        dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
+        dtype (optional): Defaults to torch.half.
+        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-5.
+        dropout_p (float, optional): Defaults to None (no dropout).
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        dim_model: int,
+        num_heads: int,
+        num_kv_heads: int,
+        dim_head: int,
+        dtype=torch.half,
+        eps: float = 1e-5,
+        dropout_p: Optional[float] = None,
+        scale: bool = True,
+        add_qkv_bias: bool = False,
+        use_flash_attn: bool = False,
+        tp: int = 0,
+    ):
+        super().__init__()
+
+        self.layernorm_before_attention = LayerNorm(
+            dim_model,
+            dtype=dtype,
+            eps=eps,
+        )
+
+        self.self_attention = Attention(
+            dim_model=dim_model,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            dim_head=dim_head,
+            dtype=dtype,
+            dropout_p=dropout_p,
+            scale=scale,
+            add_qkv_bias=add_qkv_bias,
+            use_flash_attn=use_flash_attn,
+            tp=tp,
+        )
+
+        if dropout_p:
+            self.dropout = torch.nn.Dropout(dropout_p)
+        else:
+            self.dropout = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_bias: Optional[Union[torch.Tensor, RotaryEmbedding, RotaryEmbeddingESM]] = None,
+        use_cache: bool = False,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        pos_bias_type: Optional[str] = "relative",
+        length_mask: Optional[torch.Tensor] = None,
+        attention_mask_bias: Optional[torch.Tensor] = None,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+        position_ids: Optional[torch.Tensor] = None,
+    ):
+        """
+        Args:
+            hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of the self-attention block. It can be the embedding of a batch of sequences.
+            attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Mask that keeps invalid positions out of the attention computation.
+            position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provides positional information to the self-attention block.
+
+        Return:
+            :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of the attention block.
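+
+        Example:
+            A minimal usage sketch; the sizes are arbitrary, and whether
+            ``position_bias`` may be left as ``None`` (and which
+            ``pos_bias_type`` values are accepted) depends on the
+            ``Attention`` implementation imported above, which is not part
+            of this file::
+
+                block = SelfAttentionBlock(
+                    dim_model=1024, num_heads=16, num_kv_heads=16, dim_head=64
+                )
+                h = torch.randn(2, 8, 1024, dtype=torch.half)
+                mask = torch.ones(2, 8, 8, dtype=torch.bool)
+                out, present_kv = block(h, attention_mask=mask, use_cache=True)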
+
+        """  # noqa: E501
+        x = self.layernorm_before_attention(hidden_states)
+        x = self.self_attention(
+            x,
+            x,
+            attention_mask,
+            position_bias,
+            use_cache,
+            past_key_value,
+            pos_bias_type=pos_bias_type,
+            length_mask=length_mask,
+            attention_mask_bias=attention_mask_bias,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            position_ids=position_ids,
+        )
+        if use_cache:
+            x, current_key_value = x
+        else:
+            current_key_value = None
+
+        if self.dropout is not None:
+            x = self.dropout(x)
+        hidden_states = hidden_states + x  # / 1.05
+        if use_cache:
+            return hidden_states, current_key_value
+        else:
+            return hidden_states
+
+
+class FFNBlock(torch.nn.Module):
+    """The whole feed-forward block. A sequence of operations: layernorm, feed-forward and a residual connection.
+
+    Args:
+        dim_model (int): main dimension of modules in transformer blocks.
+        dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
+        dtype (optional): Defaults to torch.half.
+        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
+        dropout_p (float, optional): Defaults to 0.
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        dim_model: int,
+        dim_ff: int,
+        activate_fn: str,
+        dtype=torch.half,
+        eps: float = 1e-6,
+        dropout_p: Optional[float] = 0,
+        scale: bool = True,
+        tp: int = 0,
+    ):
+        super().__init__()
+
+        self.layernorm_before_ffn = LayerNorm(
+            dim_model,
+            dtype=dtype,
+            eps=eps,
+        )
+
+        self.ffn = FeedForward(
+            dim_model,
+            dim_ff,
+            activate_fn=activate_fn,
+            dtype=dtype,
+            dropout_p=dropout_p,
+            scale=scale,
+            tp=tp,
+        )
+
+        if dropout_p:
+            self.dropout = torch.nn.Dropout(dropout_p)
+        else:
+            self.dropout = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ):
+        """
+        Args:
+            hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Hidden states before the feed-forward layer.
+
+        Return:
+            :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of the feed-forward block.
+
+        """  # noqa: E501
+        x = self.layernorm_before_ffn(hidden_states)
+        x = self.ffn(x)
+        if self.dropout is not None:
+            x = self.dropout(x)
+        hidden_states = hidden_states + x  # / 1.05
+        return hidden_states
+
+
+class TransformerBlock(torch.nn.Module):
+    """The whole transformer block. A sequence of operations: a self-attention block followed by a feed-forward block.
+
+    Args:
+        dim_model (int): main dimension of modules in transformer blocks.
+        dim_ff (int): dim_ff used in :py:class:`model_center.layer.FeedForward`.
+        num_heads (int): num_heads used in :py:class:`model_center.layer.Attention`.
+        dim_head (int): dim_head used in :py:class:`model_center.layer.Attention`.
+        dtype (optional): Defaults to torch.half.
+        eps (float, optional): eps used in :py:class:`model_center.layer.LayerNorm`. Defaults to 1e-6.
+        dropout_p (float, optional): Defaults to None (no dropout).
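+
+    Example:
+        A construction-and-forward sketch. The sibling ``Attention``,
+        ``FeedForward`` and ``LayerNorm`` modules are imported from this
+        package but are not included in this patch, and ``bmtrain`` normally
+        has to be initialised (e.g. via ``bmt.init_distributed()``) before
+        ``bmt.DistributedModule`` subclasses are instantiated::
+
+            block = TransformerBlock(
+                dim_model=1024, dim_ff=4096,
+                num_heads=16, num_kv_heads=16, dim_head=64,
+            )
+            x = torch.randn(2, 8, 1024, dtype=torch.half)
+            mask = torch.ones(2, 8, 8, dtype=torch.bool)
+            out = block(x, self_attention_mask=mask)
+            out, present_kv = block(x, self_attention_mask=mask, use_cache=True)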
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        dim_model: int,
+        dim_ff: int,
+        num_heads: int,
+        num_kv_heads: int,
+        dim_head: int,
+        activate_fn: str = "gelu",
+        dtype=torch.half,
+        eps: float = 1e-6,
+        dropout_p: Optional[float] = None,
+        scale: bool = True,
+        add_qkv_bias: bool = False,
+        mask_att: bool = False,
+        mask_ffn: bool = False,
+        use_flash_attn: bool = False,
+        tp: int = 0,
+    ):
+        super().__init__()
+        self.mask_att = mask_att
+        self.mask_ffn = mask_ffn
+
+        if not self.mask_att:
+            self.self_att = SelfAttentionBlock(
+                dim_model=dim_model,
+                num_heads=num_heads,
+                num_kv_heads=num_kv_heads,
+                dim_head=dim_head,
+                dtype=dtype,
+                eps=eps,
+                dropout_p=dropout_p,
+                scale=scale,
+                add_qkv_bias=add_qkv_bias,
+                use_flash_attn=use_flash_attn,
+                tp=tp,
+            )
+
+        if not self.mask_ffn:
+            self.ffn = FFNBlock(
+                dim_model=dim_model,
+                dim_ff=dim_ff,
+                activate_fn=activate_fn,
+                dtype=dtype,
+                eps=eps,
+                dropout_p=dropout_p,
+                scale=scale,
+                tp=tp,
+            )
+
+    def forward(
+        self,
+        self_hidden_states: torch.Tensor,
+        self_attention_mask: Optional[torch.Tensor] = None,
+        self_position_bias: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        pos_bias_type: Optional[str] = "relative",
+        length_mask: Optional[torch.Tensor] = None,
+        attention_mask_bias: Optional[torch.Tensor] = None,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+        position_ids: Optional[torch.Tensor] = None,
+    ):
+        """
+        Args:
+            self_hidden_states (:obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``): Input of the transformer block (the self-attention block). It can be the raw embedding of a batch of sequences.
+            self_attention_mask (:obj:`torch.Tensor` of shape ``(batch, seq_self, seq_self)``): Mask that keeps invalid positions out of the self-attention computation.
+            self_position_bias (:obj:`torch.Tensor` of shape ``(num_heads, seq_self, seq_self)``): Provides positional information to the self-attention block.
+
+        Return:
+            :obj:`torch.Tensor` of shape ``(batch, seq_self, dim_model)``: The output of the transformer block.
+
+        """  # noqa: E501
+        # (batch, seq_self, dim_model)
+        current_key_value = None
+        if not self.mask_att:
+            hidden_states = self.self_att(
+                self_hidden_states,
+                attention_mask=self_attention_mask,
+                position_bias=self_position_bias,
+                use_cache=use_cache,
+                past_key_value=past_key_value,
+                pos_bias_type=pos_bias_type,
+                length_mask=length_mask,
+                attention_mask_bias=attention_mask_bias,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+                position_ids=position_ids,
+            )
+
+            if use_cache:
+                hidden_states, current_key_value = hidden_states
+        else:
+            hidden_states = self_hidden_states
+
+        # (batch, seq_self, dim_model)
+        if not self.mask_ffn:
+            hidden_states = self.ffn(hidden_states)
+
+        if use_cache:
+            return hidden_states, current_key_value
+        else:
+            return hidden_states
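Usage note on the cache path (illustrative sketch, not part of the uploaded file): with ``use_cache=True`` the block returns ``(hidden_states, current_key_value)``, and the cached pair can be fed back as ``past_key_value`` on the next decoding step. Continuing the docstring example above, and assuming the underlying ``Attention`` module accepts the cache together with a ``(batch, 1, seq_past + 1)`` step mask:

    out, present_kv = block(x, self_attention_mask=mask, use_cache=True)
    nxt = torch.randn(2, 1, 1024, dtype=torch.half)  # one new token per sequence
    step_mask = torch.ones(2, 1, x.size(1) + 1, dtype=torch.bool)
    out, present_kv = block(
        nxt,
        self_attention_mask=step_mask,
        use_cache=True,
        past_key_value=present_kv,
    )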