上一篇我们已经介绍了关于flash attention大部分的技术细节,这些细节同时适用于flash attention1和flash attention2。

flash attention1已经实现了较为显著的性能提升,但是也仅达到了25%~40%的GEMM(General Matrix Multiply)的理论最大FLOPs/s。flash attention的作者通过分析,发现是由于在GPU的不同线程块和warps上的任务切分还不够优化,造成了一些低利用率或者不必要的共享内存的读写操作。进而作者又提出了flash attention2算法,对任务的切分进行了优化,具体来说主要有:(1)调整算法,减少了非矩阵乘法的FLOPs。在深度学习中,通常会使用矩阵乘法运算来进行前向传播和反向传播。这是因为矩阵乘法是一种高效的数值运算,可以在现代硬件上被高效地实现。然而,并不是所有的运算都可以被表示成矩阵乘法的形式。有些运算可能需要使用其他的数值计算方法,这些方法可能会涉及到更多的浮点运算。(2)更大程度的提高了attention计算的并行度,甚至对于单个头的计算,也会将其分发到多个不同的线程块中执行计算,此举相比flash attention1,大约有2x的性能提升。

关于flash attention2对GPU warps的优化调整,flash attention2的论文中有一处说明,如下图所示。

flash attention1的forward计算中,对于每一个block,是将K,V切分到4个不同的warps(warps 是NVIDIA GPU并行计算的基本单元。一个Warp通常包含32个线程,它们同时执行相同的指令,但对不同的数据进行操作。在GPU执行指令时,通常以Warps为单位进行调度,这可以充分利用GPU的并行处理能力)上,但是将Q保持为对所有的warps是可见的。关于这样修改为什么会减少shared memory的读写以提高性能,paper的原文是这么说的:

在这里我就不做过多的解释(因为我也不懂,涉及到GPU更底层的实现相关。flash attention是使用cutlass实现的,cutlass相对偏底层,从下图可以看出,cutlass比直接写CUDA会更高级一些,但是相比triton,是偏底层)。

下面我们重点放在flash attention2算法的forward计算的实现上。

flash attention2算法的计算流程如下图所示:

flash attention2与flash attention1在算法层面大部分都是相同的,只是少部分地方做了修改,因此我们不做过多的解释,直接通过代码来逐行编程实现。

import torch

torch.manual_seed(456)

N, d = 16, 8
Q_mat = torch.rand((N, d))
K_mat = torch.rand((N, d))
V_mat = torch.rand((N, d))

expected_softmax = torch.softmax(Q_mat @ K_mat.T, dim=1)
expected_attention = expected_softmax @ V_mat

# 分块(tiling)尺寸,以SRAM的大小计算得到
Br = 4
Bc = d

O = torch.zeros((N, d))

# 算法流程第3步,执行外循环
for block_start_Br in range(0, N, Br):
    block_end_Br = block_start_Br + Br
    # 算法流程第4步,从HBM中load Qi 的一个block到SRAM
    Qi = Q_mat[block_start_Br:block_end_Br, :]
    # 算法流程第5步,初始化每个block的值
    Oi = torch.zeros((Br, d))  # shape Br x d
    li = torch.zeros((Br, 1))  # shape Br x 1
    mi = torch.full((Br, 1), -torch.inf)  # shape Br x 1

    # 算法流程第6步,执行内循环
    for block_start_Bc in range(0, N, Bc):
        block_end_Bc = block_start_Bc + Bc

        # 算法流程第7步,load Kj, Vj到SRAM
        Kj = K_mat[block_start_Bc:block_end_Bc, :]
        Vj = V_mat[block_start_Bc:block_end_Bc, :]

        # 算法流程第8步
        Sij = Qi @ Kj.T
        # 算法流程第9步
        mi_new = torch.max(torch.column_stack([mi, torch.max(Sij, dim=1).values[:, None]]), dim=1).values[:, None]
        Pij_hat = torch.exp(Sij - mi_new)
        li = torch.exp(mi - mi_new) * li + torch.sum(Pij_hat, dim=1)[:, None]
        # 算法流程第10步
        Oi = Oi * torch.exp(mi - mi_new) + Pij_hat @ Vj

        mi = mi_new

    # 第12步
    Oi = Oi / li

    # 第14步
    O[block_start_Br:block_end_Br, :] = Oi
assert torch.allclose(O, expected_attention)

上面的实现只是将算法的计算流程进行了编程实现。但是在实际使用中,会结合GPU的能力进行大规模并行计算。目前大众开发者GPU的编程主要会使用CUDA和triton两种语言。cuda语言大家比较熟悉,triton在这里略作介绍。

triton是一种类似 Python 的开源编程语言,它能让没有 CUDA 经验的研究人员编写高效的 GPU 代码—在大多数情况下与专家编写的cuda代码不相上下。即我们使用 python语言和triton的接口编写完相关计算后,triton编译器会生成高效的cuda代码。triton是openai发布的一项技术,目前国内很多公司也在使用triton生成的cuda代码作为参考。具体的benchmark等信息可以参考openai triton

下面是flash attention2的triton代码实现。

"""
Fused Attention
===============

This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team

Extra Credits:
- Original flash attention paper (https://arxiv.org/abs/2205.14135)
- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)

"""

import pytest
import torch

import triton
import triton.language as tl


@triton.jit
def _attn_fwd_inner(
    acc, l_i, m_i, q,
    K_block_ptr, V_block_ptr,
    start_m, qk_scale,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    else:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(K_block_ptr)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]
        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        v = tl.load(V_block_ptr)
        acc += tl.dot(p.to(tl.float16), v)
        # update m_i and l_i
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i


@triton.jit
def _attn_fwd(
    Q, K, V, sm_scale, M, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H,
    N_CTX: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/log(2)
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    # stage 1: off-band
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
            start_m, qk_scale,
            BLOCK_M, BLOCK_DMODEL, BLOCK_N,
            1, offs_m, offs_n,
        )
    # barrier makes it easier for compielr to schedule the
    # two loops independently
    tl.debug_barrier()
    # stage 2: on-band
    if STAGE & 2:
        acc, l_i, m_i = _attn_fwd_inner(
            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
            start_m, qk_scale,
            BLOCK_M, BLOCK_DMODEL, BLOCK_N,
            2, offs_m, offs_n,
        )
    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))


empty = torch.empty(128, device="cuda")

class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        BLOCK_M = 128
        BLOCK_N = 64 if Lk <= 64 else 32
        num_stages = 4 if Lk <= 64 else 3
        num_warps = 4
        # Tuning for H100
        if torch.cuda.get_device_capability()[0] == 9:
            num_warps = 8
            num_stages = 7 if Lk >= 64 else 3
        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        _attn_fwd[grid](
            q, k, v, sm_scale, M, o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1],
            N_CTX=q.shape[2],
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_DMODEL=Lk,
            STAGE=3,
            num_warps=num_warps,
            num_stages=num_stages,
        )

        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        ctx.causal = causal
        return o

attention = _attention.apply

我们看上面代码的这部分

p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp2(m_i - m_ij)
l_i = l_i * alpha + l_ij
# -- update output accumulator --
acc = acc * alpha[:, None]
# update acc
v = tl.load(V_block_ptr)
acc += tl.dot(p.to(tl.float16), v)
# update m_i and l_i
m_i = m_ij

就是算法流程图的按步计算,与我们用纯python实现的过程基本一致。我在实现python版的时,也借鉴了triton版本的相关计算过程。因此也可以发现,triton可以让我们用相对抽象的语言写出高性能cuda代码。下面我们会对triton的实现进行性能benchmark。

然后我们将cutlass实现的flash attention2(flash attention2的默认实现方式)与triton实现的flash attention2进行性能对比。

try:
    # flash attention的标准使用接口
    from flash_attn.flash_attn_interface import \
        flash_attn_qkvpacked_func as flash_attn_func
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [
    triton.testing.Benchmark(
        x_names=["N_CTX"],
        x_vals=[2**i for i in range(10, 15)],
        line_arg="provider",
        line_vals=["triton"] + (["flash"] if HAS_FLASH else []),
        line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []),
        styles=[("red", "-"), ("blue", "-")],
        ylabel="ms",
        plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}",
        args={
            "H": N_HEADS,
            "BATCH": BATCH,
            "D_HEAD": D_HEAD,
            "dtype": torch.float16,
            "mode": mode,
            "causal": causal,
        },
    )
    for mode in ["fwd"]
    for causal in [True]
]


@triton.testing.perf_report(configs)
def bench_flash_attention(
    BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"
):
    assert mode in ["fwd"]
    warmup = 25
    rep = 100
    if provider == "triton":
        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        if mode == "fwd":
            q = q.to(torch.float8_e5m2)
            k = k.to(torch.float8_e5m2)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        sm_scale = 1.3
        fn = lambda: attention(q, k, v, causal, sm_scale)
        if mode == "bwd":
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "flash":
        qkv = torch.randn(
            (BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True
        )
        fn = lambda: flash_attn_func(qkv, causal=causal)
        if mode == "bwd":
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    if mode == "bwd":
        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
    return total_flops / ms * 1e-9


# only works on post-Ampere GPUs right now
bench_flash_attention.run(save_path=".", print_data=True)

在A100上测试,结果如下:

batch4-head48-d64 forward,单位FLOPs/s

N_CTX(context length) triton flash attention2(cutlass)
1024 123 137
2048 159 162
4096 163 159
8192 167 157
16384 167 165

从前向计算的结果来看,triton的性能在context length较长的情况下,甚至好于cutlass实现的flash attention2。

但是triton实现的flash attention2相比默认使用cutlass实现的,backward计算时,triton的性能大约是cutlass的3/4。后续有机会会补充backward的实现。