From 1ec35926fcfd1f4e2652b69351dd5a27c72e1bf3 Mon Sep 17 00:00:00 2001
From: paxflsu4r <198028451@qq.com>
Date: Mon, 20 Jan 2025 16:55:53 +0800
Subject: [PATCH] ADD file via upload

---
 _attention_pp_sp.py | 89 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 89 insertions(+)
 create mode 100644 _attention_pp_sp.py

diff --git a/_attention_pp_sp.py b/_attention_pp_sp.py
new file mode 100644
index 0000000..cab9825
--- /dev/null
+++ b/_attention_pp_sp.py
@@ -0,0 +1,89 @@
+import torch
+import torch.nn.functional as F
+import bmtrain as bmt
+from flash_attn.flash_attn_interface import (
+    _flash_attn_varlen_forward,
+    _flash_attn_varlen_backward,
+)
+
+
+def _linear_backward(grad_output, x, weight, bias):
+    """Manually backpropagate through F.linear(x, weight, bias)."""
+    grad_x = grad_weight = grad_bias = None
+    if x.requires_grad:
+        grad_x = grad_output.matmul(weight)
+    if weight.requires_grad:
+        grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
+            x.reshape(-1, x.shape[-1])
+        )
+    if bias is not None and bias.requires_grad:
+        grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
+    return grad_x, grad_weight, grad_bias
+
+
+class OpAttnPipeSP(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q_w, k_w, v_w, q_b, k_b, v_b, x, cache_kv, cache_kv_inp,
+                cu_seqlens_q, cu_seqlens_k, max_seqlen):
+        # Project the packed hidden states into q/k/v. The projections are
+        # assumed to already match the (total_tokens, num_heads, head_dim)
+        # layout that the flash-attn varlen kernels expect.
+        q = F.linear(x, q_w, q_b)
+        k = F.linear(x, k_w, k_b)
+        v = F.linear(x, v_w, v_b)
+        if cache_kv.numel() != 0:
+            # Decoding with a KV cache: prepend the cached keys and values.
+            k = torch.cat([cache_kv[0], k], dim=1)
+            v = torch.cat([cache_kv[1], v], dim=1)
+
+        softmax_scale = q.shape[-1] ** (-0.5)
+        # The private flash-attn helpers change signature between releases;
+        # this call follows the flash-attn 2.x interface.
+        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
+            q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen, max_seqlen,
+            dropout_p=0.0, softmax_scale=softmax_scale, causal=True,
+            window_size=(-1, -1), alibi_slopes=None, return_softmax=False,
+        )
+        ctx.save_for_backward(
+            x, q_w, k_w, v_w, q_b, k_b, v_b,
+            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state,
+        )
+        ctx.max_seqlen_q = max_seqlen
+        ctx.max_seqlen_k = max_seqlen
+        ctx.dropout_p = 0.0
+        ctx.softmax_scale = softmax_scale
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (x, q_w, k_w, v_w, q_b, k_b, v_b,
+         q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) = ctx.saved_tensors
+        dout = grad_output.contiguous()
+        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
+        _flash_attn_varlen_backward(
+            dout, q, k, v, out, softmax_lse, dq, dk, dv,
+            cu_seqlens_q, cu_seqlens_k,
+            ctx.max_seqlen_q, ctx.max_seqlen_k,
+            ctx.dropout_p, ctx.softmax_scale,
+            True,       # causal, matching the forward pass
+            (-1, -1),   # window_size
+            None,       # alibi_slopes
+            False,      # deterministic
+            rng_state=rng_state,
+        )
+        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
+        dk = dk[..., : dout.shape[-1]]
+        dv = dv[..., : dout.shape[-1]]
+
+        # Backpropagate through the three input projections; dq/dk/dv are
+        # assumed to share the projections' output layout. The gradient
+        # w.r.t. x accumulates contributions from all three.
+        d_xq, d_wq, d_bq = _linear_backward(dq, x, q_w, q_b)
+        d_xk, d_wk, d_bk = _linear_backward(dk, x, k_w, k_b)
+        d_xv, d_wv, d_bv = _linear_backward(dv, x, v_w, v_b)
+        d_x = d_xq + d_xk + d_xv
+
+        # One gradient per forward input: (q_w, k_w, v_w, q_b, k_b, v_b, x,
+        # cache_kv, cache_kv_inp, cu_seqlens_q, cu_seqlens_k, max_seqlen).
+        return (d_wq, d_wk, d_wv, d_bq, d_bk, d_bv, d_x,
+                None, None, None, None, None)
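
Reviewer note (not part of the patch): the hand-written projection backward in _linear_backward can be sanity-checked against PyTorch autograd with a short CPU-only script along the lines below. The script name, shapes, and tolerance are assumptions for illustration, and importing from _attention_pp_sp only works in an environment where flash_attn and bmtrain are installed, since the module imports both at the top.

    # check_linear_backward.py -- hypothetical sanity check, not part of the patch
    import torch
    import torch.nn.functional as F

    from _attention_pp_sp import _linear_backward

    x = torch.randn(4, 8, 16, requires_grad=True)
    w = torch.randn(32, 16, requires_grad=True)
    b = torch.randn(32, requires_grad=True)

    # Reference gradients from autograd.
    y = F.linear(x, w, b)
    grad_out = torch.randn_like(y)
    y.backward(grad_out)

    # Gradients from the manual backward used by OpAttnPipeSP.
    gx, gw, gb = _linear_backward(grad_out, x, w, b)
    assert torch.allclose(gx, x.grad, atol=1e-4)
    assert torch.allclose(gw, w.grad, atol=1e-4)
    assert torch.allclose(gb, b.grad, atol=1e-4)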