# ft-left-pythia-160/megatron/model/flash_attention.py
# Based on: https://github.com/HazyResearch/flash-attention/blob/4a6eaa9f27df6fff7ffb2c24e894938a687dd870/flash_attn/flash_attn_interface.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import flash_attn_cuda
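

# flash_attn_cuda is the compiled CUDA extension shipped with the flash-attn package
# referenced above; the helpers below are thin Python wrappers around its fwd/bwd kernels.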
def _flash_attn_forward(
q,
k,
v,
out,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
return_softmax,
num_splits=0,
generator=None,
):
"""
num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
Don't change it unless you know what you're doing.
"""
softmax_lse, *rest = flash_attn_cuda.fwd(
q,
k,
v,
out,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
        False,  # zero_tensors flag in the upstream extension signature; left disabled
causal,
return_softmax,
num_splits,
generator,
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
S_dmask = rest[0] if return_softmax else None
return out, softmax_lse, S_dmask


def _flash_attn_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
num_splits=0,
generator=None,
):
"""
num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel
as num_splits=3), so effectively the choices are 0, 1, and 2.
    This hyperparameter can be tuned for performance, but the default (heuristic) value should work fine.
"""
_, _, _, softmax_d = flash_attn_cuda.bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
        False,  # zero_tensors flag in the upstream extension signature; left disabled
causal,
num_splits,
generator,
)
    # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dq, dk, dv, softmax_d
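

# Autograd wrapper for FlashAttention on packed QKV, where dim 1 of `qkv` holds the
# query, key, and value projections and cu_seqlens marks the variable-length
# (unpadded) sequence boundaries.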
class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
qkv,
cu_seqlens,
max_seqlen,
dropout_p,
softmax_scale,
causal,
return_softmax,
):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)  # default to 1 / sqrt(head_dim)
out, softmax_lse, S_dmask = _flash_attn_forward(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
            torch.empty_like(qkv[:, 0]),  # pre-allocated output buffer
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
dropout_p,
softmax_scale,
causal=causal,
return_softmax=return_softmax,
)
ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen = max_seqlen
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
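        # Temporarily replay the RNG state saved in forward so the dropout mask
        # regenerated by the backward kernel matches the one used in forward.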
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dqkv = torch.empty_like(qkv)
_flash_attn_backward(
dout,
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
out,
softmax_lse,
dqkv[:, 0],
dqkv[:, 1],
dqkv[:, 2],
cu_seqlens,
cu_seqlens,
ctx.max_seqlen,
ctx.max_seqlen,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
        # One gradient per forward input: only qkv is differentiable, the rest get None.
        return dqkv, None, None, None, None, None, None


def flash_attn_unpadded_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p,
softmax_scale=None,
causal=False,
return_attn_probs=False,
):
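    """Variable-length FlashAttention with packed QKV.

    Argument shapes follow the upstream flash-attention interface this file is based on:
        qkv: (total, 3, nheads, headdim), where total is the number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), torch.int32. Cumulative sequence lengths of the
            sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability (set to 0.0 during evaluation).
        softmax_scale: float. Scaling of QK^T before the softmax; defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask.
        return_attn_probs: bool. Whether to also return the softmax log-sum-exp and the
            attention probabilities (intended for testing).
    Return:
        out: (total, nheads, headdim), plus (softmax_lse, S_dmask) if return_attn_probs is True.
    """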
return FlashAttnQKVPackedFunc.apply(
qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
)
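

# Minimal usage sketch (illustration only, not part of the original module). It assumes
# a CUDA device, fp16 inputs, and an available flash_attn_cuda extension, following the
# packed layout (total_tokens, 3, num_heads, head_dim) used by the upstream interface.
if __name__ == "__main__":
    batch_size, seqlen, num_heads, head_dim = 2, 128, 8, 64
    qkv = torch.randn(
        batch_size * seqlen,
        3,
        num_heads,
        head_dim,
        device="cuda",
        dtype=torch.float16,
        requires_grad=True,
    )
    # Cumulative sequence lengths for two equal-length sequences: [0, 128, 256].
    cu_seqlens = torch.arange(
        0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device="cuda"
    )
    out = flash_attn_unpadded_qkvpacked_func(
        qkv, cu_seqlens, max_seqlen=seqlen, dropout_p=0.0, causal=True
    )  # out: (total_tokens, num_heads, head_dim)
    out.sum().backward()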