import torch
import torch.nn as nn
import torch.nn.functional as F

import flash_attn_cuda


def _flash_attn_forward(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    return_softmax,
    num_splits=0,
    generator=None,
):
    """
    num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
    it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
    Don't change it unless you know what you're doing.
    """
    softmax_lse, *rest = flash_attn_cuda.fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        return_softmax,
        num_splits,
        generator,
    )
    S_dmask = rest[0] if return_softmax else None
    return out, softmax_lse, S_dmask

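
# Note on the packed ("unpadded") layout these wrappers assume: the sequences of a
# batch are concatenated along the token dimension, and cu_seqlens_q / cu_seqlens_k
# hold the cumulative sequence lengths that delimit them, as an int32 tensor on the
# same device. For example (illustrative values, not from this file), per-sequence
# lengths [2, 3, 4] correspond to
#     cu_seqlens = torch.tensor([0, 2, 5, 9], dtype=torch.int32, device="cuda")
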
def _flash_attn_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    num_splits=0,
    generator=None,
):
    """
    num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
    not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
    Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same
    kernel as num_splits=3), so effectively the choices are 0, 1, and 2.
    This hyperparameter can be tuned for performance, but the default value (heuristic)
    should work fine.
    """
    # The CUDA kernel writes the gradients into the preallocated dq, dk, dv buffers;
    # they are returned here for convenience along with softmax_d.
    _, _, _, softmax_d = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        num_splits,
        generator,
    )
    return dq, dk, dv, softmax_d

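
# FlashAttention does not materialize or save the full attention matrix. The autograd
# Function below saves only the inputs, the output, softmax_lse (the row-wise
# log-sum-exp from the forward pass), and the dropout RNG state; the backward kernel
# recomputes the attention blocks on the fly from these.
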
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        # Save the RNG state only when dropout is active, so the backward pass can
        # replay the exact same dropout pattern.
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            # Default scaling: 1 / sqrt(headdim).
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out, softmax_lse, S_dmask = _flash_attn_forward(
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            torch.empty_like(qkv[:, 0]),
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax,
        )
        ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            # Temporarily restore the forward pass's RNG state so dropout is replayed
            # identically, then put the current state back afterwards.
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = torch.empty_like(qkv)
        _flash_attn_backward(
            dout,
            qkv[:, 0],
            qkv[:, 1],
            qkv[:, 2],
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        # One gradient per forward input; only qkv receives a gradient.
        return dqkv, None, None, None, None, None, None


def flash_attn_unpadded_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """
    Arguments:
        qkv: (total_tokens, 3, nheads, headdim); all sequences packed along dim 0.
        cu_seqlens: (batch_size + 1,), torch.int32 cumulative sequence lengths
            delimiting the packed sequences.
        max_seqlen: int, maximum sequence length in the batch.
        dropout_p: dropout probability.
        softmax_scale: scaling of QK^T before the softmax; defaults to 1/sqrt(headdim).
        causal: whether to apply a causal attention mask.
        return_attn_probs: whether to also return the attention probabilities
            (for testing only).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
    )
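

# Minimal usage sketch (illustrative only, not part of the module's API): runs the
# packed-QKV interface over two sequences of lengths 3 and 5. It assumes a CUDA
# device with the flash_attn_cuda extension built, an fp16 qkv tensor in the packed
# layout described above, and int32 cu_seqlens; the chosen sizes are arbitrary.
if __name__ == "__main__":
    nheads, headdim = 4, 64
    seqlens = [3, 5]
    total = sum(seqlens)
    # Cumulative sequence lengths delimiting the two packed sequences: [0, 3, 8].
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
    qkv = torch.randn(
        total, 3, nheads, headdim, dtype=torch.float16, device="cuda", requires_grad=True
    )
    out = flash_attn_unpadded_qkvpacked_func(
        qkv, cu_seqlens, max(seqlens), dropout_p=0.0, causal=True
    )
    print(out.shape)  # expected: (total, nheads, headdim)
    out.sum().backward()
    print(qkv.grad.shape)  # expected: (total, 3, nheads, headdim)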