diff --git a/position_embedding.py b/position_embedding.py
new file mode 100644
index 0000000..b769fa2
--- /dev/null
+++ b/position_embedding.py
@@ -0,0 +1,368 @@
+import math
+from typing import Tuple
+from typing import Union
+
+import bmtrain as bmt
+import torch
+import torch.nn.functional as F
+
+try:
+    from flash_attn.layers.rotary import apply_rotary_emb_func
+except ImportError:
+    apply_rotary_emb_func = None  # flash-attn is optional; fall back to the pure PyTorch path
+
+
+class SegmentPositionEmbedding(bmt.DistributedModule):
+    def __init__(
+        self,
+        num_heads: int,
+        num_segments: int = 1,
+        num_buckets: int = 32,
+        max_distance: int = 128,
+        bidirectional: bool = False,
+        dtype: torch.dtype = torch.half,
+        init_mean: float = 0.0,
+        init_std: float = 1,
+    ):
+        super().__init__()
+
+        self.num_heads = num_heads
+        self.num_buckets = num_buckets
+        self.max_distance = max_distance
+        self.bidirectional = bidirectional
+        self.num_segments = num_segments
+
+        self.relative_attention_bias = bmt.DistributedParameter(
+            torch.empty(num_segments * num_segments + num_buckets, num_heads, dtype=dtype),
+            init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
+        )
+
+    def forward(
+        self,
+        key_pos: torch.Tensor,
+        query_pos: torch.Tensor,
+        key_segment: torch.Tensor,
+        query_segment: torch.Tensor,
+    ):
+        with torch.no_grad():
+            batch = key_pos.size(0)
+            keylen = key_pos.size(1)
+            querylen = query_pos.size(1)
+
+            assert key_pos.size(0) == query_pos.size(0)
+            assert keylen == key_segment.size(1) and querylen == query_segment.size(1)
+
+            key_pos = key_pos.view(batch, -1, keylen)
+            query_pos = query_pos.view(batch, querylen, -1)
+            key_segment = key_segment.view(batch, -1, keylen)
+            query_segment = query_segment.view(batch, querylen, -1)
+
+            relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
+            relative_position_bucket = relative_position_bucket + self.num_buckets  # do not overlap with the relative position bucket range
+
+            # b*q*k
+            absolute_position_bucket = self._position_bucket(
+                torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
+                - torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
+                bidirectional=self.bidirectional,
+                num_buckets=self.num_buckets,
+                max_distance=self.max_distance,
+            )
+            relative_position_bucket = torch.where(
+                (key_segment == query_segment),
+                absolute_position_bucket[None, :, :],
+                relative_position_bucket,
+            )
+            # (batch, len_q, len_k)
+
+        # (batch, len_q, len_k, num_heads)
+        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
+        # (batch, num_heads, len_q, len_k)
+        embeds = embeds.permute(0, 3, 1, 2).contiguous()
+        return embeds
+
+    def _segment_relative_position_bucket(self, query_segment, key_segment):
+        return query_segment * self.num_segments + key_segment
+
+    def _position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        relative_buckets = 0
+        if bidirectional:
+            num_buckets //= 2
+            relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
+            relative_position = torch.abs(relative_position)
+        else:
+            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+        max_exact = num_buckets // 2
+        is_small = relative_position < max_exact
+        relative_position_if_large = max_exact + (
+            torch.log(relative_position.float() / max_exact)
+            / math.log(max_distance / max_exact)
+            * (num_buckets - max_exact)
+        ).to(torch.int32)
+        relative_position_if_large = torch.min(
+            relative_position_if_large,
+            torch.full_like(relative_position_if_large, num_buckets - 1),
+        )
+        relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_position_if_large)
+        return relative_buckets
+
+
+class BucketPositionBias(bmt.DistributedModule):
+    def __init__(
+        self,
+        num_heads: int,
+        num_buckets: int = 32,
+        num_segment_bucket: int = 32,
+        max_distance: int = 128,
+        dtype: torch.dtype = torch.half,
+        init_mean: float = 0.0,
+        init_std: float = 1,
+    ) -> None:
+        super().__init__()
+
+        self.num_heads = num_heads
+        self.num_buckets = num_buckets
+        self.num_segment_bucket = num_segment_bucket
+        self.max_distance = max_distance
+
+        self.relative_attention_bias = bmt.DistributedParameter(
+            torch.empty(num_buckets + num_segment_bucket, num_heads, dtype=dtype),
+            init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std),
+        )
+
+    def forward(
+        self,
+        query_pos: torch.Tensor,  # (batch, len_q)
+        key_pos: torch.Tensor,  # (batch, len_k)
+        rel_buckets: torch.Tensor,  # (batch, len_q, len_k)
+    ):
+        with torch.no_grad():
+            batch = key_pos.size(0)
+            keylen = key_pos.size(1)
+            querylen = query_pos.size(1)
+
+            assert key_pos.size(0) == query_pos.size(0)
+            assert rel_buckets.size(0) == batch and rel_buckets.size(1) == querylen and rel_buckets.size(2) == keylen
+
+            relative_position_bucket = rel_buckets - 1 + self.num_buckets  # do not overlap with the relative position bucket range
+
+            # b*q*k
+            inner_segment_bucket = self._position_bucket(
+                key_pos[..., None, :] - query_pos[..., :, None],
+                num_buckets=self.num_buckets,
+                max_distance=self.max_distance,
+            )
+            relative_position_bucket = torch.where(
+                rel_buckets == 0,
+                inner_segment_bucket,
+                relative_position_bucket,
+            )
+            # (batch, len_q, len_k)
+
+        # (batch, len_q, len_k, num_heads)
+        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
+        # (batch, num_heads, len_q, len_k)
+        embeds = embeds.permute(0, 3, 1, 2).contiguous()
+        return embeds
+
+    def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
+        relative_buckets = 0
+        num_buckets //= 2
+        relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
+        relative_position = torch.abs(relative_position)
+
+        max_exact = num_buckets // 2
+        is_small = relative_position < max_exact
+        relative_position_if_large = max_exact + (
+            torch.log(relative_position.float() / max_exact)
+            / math.log(max_distance / max_exact)
+            * (num_buckets - max_exact)
+        ).to(torch.int32)
+        relative_position_if_large = torch.min(
+            relative_position_if_large,
+            torch.full_like(relative_position_if_large, num_buckets - 1),
+        )
+        relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_position_if_large)
+        return relative_buckets
+
+
+class RotaryEmbedding(bmt.DistributedModule):
+    def __init__(
+        self,
+        dim,
+        base: Union[int, float] = 10000,
+        distance_scale: Union[int, float] = 1,
+        dtype: torch.dtype = torch.half,
+    ):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim))
+        inv_freq = inv_freq.to(dtype)
+        self.distance_scale = distance_scale
+        self.dtype = dtype
+        self.inv_freq = inv_freq
+
+    def forward(self, x: torch.Tensor, x_pos: torch.Tensor):
+        """
+        Args:
+            x (:obj:`torch.Tensor` of shape ``(..., dim)``): Inputs.
+            x_pos (:obj:`torch.Tensor` of shape ``(...)``): Positions of inputs.
+ """ + x_pos = x_pos * self.distance_scale + freqs = x_pos[..., None].to(self.dtype) * self.inv_freq[None, :] # (..., dim/2) + + # the same implementation as sat + emb = torch.cat((freqs, freqs), dim=-1) # (..., dim) + emb_cos = emb.cos() # (..., dim) + emb_sin = emb.sin() # (..., dim) + + rotate_x = torch.cat([-x[..., x.size(-1) // 2 :], x[..., : x.size(-1) // 2]], dim=-1) # (..., dim) + + return x * emb_cos + rotate_x * emb_sin + + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x, cos, sin, seq_dim, offset): + if x.size(seq_dim) < cos.size(seq_dim): # == do not need narrow + cos = cos.narrow(seq_dim, offset, x.size(seq_dim)) + sin = sin.narrow(seq_dim, offset, x.size(seq_dim)) + return (x * cos) + (rotate_half(x) * sin) + + +def unpad_apply_rotary_pos_emb(x, cos, sin, seq_dim, position_ids): + cos = cos.index_select(seq_dim, position_ids.view(-1)) + sin = sin.index_select(seq_dim, position_ids.view(-1)) + return (x * cos) + (rotate_half(x) * sin) + + +class RotaryEmbeddingESM(bmt.DistributedModule): + """ + Rotary position embeddings based on those in + [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation + matrices which depend on their relative positions. + """ + + def __init__( + self, + dim: int, + base: Union[int, float] = 10000, + distance_scale: Union[int, float] = 1, + dtype=torch.half, + persistent=True, + mixed_precision=True, + ): + super().__init__() + self.base = base + self.distance_scale = distance_scale + self.dtype = dtype + + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim)) + if mixed_precision: + self.register_buffer("inv_freq", inv_freq, persistent=persistent) + else: + self.register_buffer("inv_freq", inv_freq.to(self.dtype), persistent=persistent) + + self._seq_len_cached = -1 + self._cos_cached = None + self._sin_cached = None + self.mixed_precision = mixed_precision + + self.apply_rotary_pos_emb = apply_rotary_pos_emb + self.unpad_apply_rotary_pos_emb = unpad_apply_rotary_pos_emb + + def _update_cos_sin_tables(self, x, seq_dim, seq_len): + if seq_len > self._seq_len_cached or self._cos_cached.device != x.device: + self._seq_len_cached = seq_len + t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) + freqs = torch.outer(t * self.distance_scale, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + for i in range(x.dim() - 1): + if i != seq_dim: + emb = emb.unsqueeze_(i) + if self.mixed_precision: + self._cos_cached = emb.cos().to(self.dtype) + self._sin_cached = emb.sin().to(self.dtype) + else: + self._cos_cached = emb.cos() + self._sin_cached = emb.sin() + return self._cos_cached, self._sin_cached + + def forward( + self, q: torch.Tensor, k: torch.Tensor, seq_dim, offset=0, cu_seqlens=None, max_length=None, position_ids=None + ) -> Tuple[torch.Tensor, torch.Tensor]: + seq_dim = (seq_dim + k.dim()) % k.dim() + if cu_seqlens is None: + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, k.size(seq_dim) + offset) + return ( + self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, offset), + self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, offset), + ) + else: + assert offset == 0, "past kv is not supported in flash attn" + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dim, max_length) + return ( + 
+                self.unpad_apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dim, position_ids),
+                self.unpad_apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dim, position_ids),
+            )
+
+
+@torch.jit.script
+def apply_chatglm_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
+    # x: [b, np, sq, hn]
+    x = x.permute(2, 0, 1, 3)  # [b, np, sq, hn] -> [sq, b, np, hn]
+    sq, b, np, hn = x.shape
+    rot_dim = rope_cache.shape[-2] * 2
+    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
+    # truncate to support variable sizes
+    rope_cache = rope_cache[:sq]
+    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
+    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
+            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
+        ],
+        -1,
+    )
+    x_out2 = x_out2.flatten(3)
+    ret = torch.cat((x_out2, x_pass), dim=-1)
+    ret = ret.permute(1, 2, 0, 3)  # [sq, b, np, hn] -> [b, np, sq, hn]
+    return ret
+
+
+class ChatGLMRotaryEmbedding(bmt.DistributedModule):
+    def __init__(self, dim, device="cuda", dtype=torch.float16, persistent=True):
+        super().__init__()
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=dtype, device=device) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=persistent)
+        self.dim = dim
+
+    def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000):
+        """Enhanced Transformer with Rotary Position Embedding.
+
+        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
+        transformers/rope/__init__.py. MIT License:
+        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
+        """
+        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
+
+        # Create position indexes `[0, 1, ..., seq_len - 1]`
+        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
+
+        # Calculate the product of position index and $\theta_i$
+        idx_theta = torch.outer(seq_idx, theta).float()
+
+        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
+
+        # this is to mimic the behaviour of complex32, else we will get different results
+        if dtype in (torch.float16, torch.bfloat16, torch.int8):
+            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
+        return cache
+
+    def forward(self, max_seq_len, offset: int = 0):
+        return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
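
For reviewers who want to exercise the new module, here is a minimal usage sketch for `RotaryEmbeddingESM`. It is not part of the diff: it assumes the file is importable as `position_embedding`, that the process was launched in a bmtrain-compatible way (e.g. via `torchrun`) so `bmt.init_distributed()` succeeds, and that a CUDA device is available, since `inv_freq` is created on `"cuda"`. The batch, head, and sequence sizes are illustrative only.

```python
# Hypothetical usage sketch, not part of the diff.
import bmtrain as bmt
import torch

from position_embedding import RotaryEmbeddingESM

bmt.init_distributed()  # assumes a distributed-capable launch, e.g. torchrun

# head_dim must be even because frequencies are generated per pair of channels
rope = RotaryEmbeddingESM(dim=64, dtype=torch.half)

# (batch, num_heads, seq_len, head_dim); seq_dim=-2 selects the sequence axis
q = torch.randn(2, 8, 128, 64, dtype=torch.half, device="cuda")
k = torch.randn(2, 8, 128, 64, dtype=torch.half, device="cuda")

# Dense (padded) path: cu_seqlens is None, so apply_rotary_pos_emb is used
q_rot, k_rot = rope(q, k, seq_dim=-2)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```

The unpadded path (passing `cu_seqlens`, `max_length`, and `position_ids`) is intended for flash-attn style variable-length inputs and, per the assertion in `forward`, requires `offset == 0`.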
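Similarly, a sketch of how `SegmentPositionEmbedding` might be queried for an attention bias, under the same assumptions (bmtrain initialized, CUDA available, shapes made up). Note that the class uses `key_pos` and `query_pos` only for their shapes; the absolute relative-position buckets are built internally with `torch.arange`.

```python
# Hypothetical usage sketch, not part of the diff.
import bmtrain as bmt
import torch

from position_embedding import SegmentPositionEmbedding

bmt.init_distributed()  # DistributedParameter requires an initialized bmtrain context

bias = SegmentPositionEmbedding(num_heads=8, num_segments=2, num_buckets=32, max_distance=128)

batch, seq_len = 2, 16
pos = torch.arange(seq_len, dtype=torch.int32, device="cuda").unsqueeze(0).repeat(batch, 1)
seg = torch.zeros(batch, seq_len, dtype=torch.int32, device="cuda")  # every token in segment 0

# (batch, num_heads, len_q, len_k), typically added to the attention scores before softmax
position_bias = bias(key_pos=pos, query_pos=pos, key_segment=seg, query_segment=seg)
assert position_bias.shape == (batch, 8, seq_len, seq_len)
```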