| | """ |
| | Author: Eric Lin (xihlin) |
| | """ |
| | """ |
| | ... note(bapatra):: |
| | This is written as one big file, instead of being split into logical components, because I was running into issues with transformers auto module |
| | imports when splitting into different files. I've tried keeping the logical partitions demarcated with comment blocks, but it is not ideal. |
| | In the future, it would be really good to revisit this and refactor into a more readable file structure. |
| | |
| | """ |
| | from typing import TypeVar |
| | from functools import lru_cache |
| | import math |
| | import pytest |
| | import torch |
| | import numpy as np |
| |
|
| | import triton |
| | import triton.language as tl |
| |
|
| | import os |
| |
|
| | import dataclasses |
| |
|
# Placeholder used only in type annotations (e.g. BlockSparseParams.from_config).
# NOTE(review): presumably stands in for the model's config class, which is not
# imported here because of the single-file constraint described above — confirm.
Phi3SmallConfig = TypeVar('Phi3SmallConfig')
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| | |
| | |
| |
|
@dataclasses.dataclass
class BlockSparseParams(object):
    """Block-sparse attention settings.

    Each field mirrors one ``blocksparse_*`` attribute of the model config;
    ``from_config`` performs that mapping.
    """

    # config.blocksparse_block_size
    block_size: int
    # config.blocksparse_triton_kernel_block_size
    kernel_block_size: int
    # config.blocksparse_num_local_blocks
    num_local_blocks: int
    # config.blocksparse_vert_stride
    vert_stride: int
    # config.blocksparse_homo_head_pattern
    homo_head_pattern: bool = False

    @classmethod
    def from_config(cls, config: Phi3SmallConfig) -> "BlockSparseParams":
        """Build an instance from the ``blocksparse_*`` attributes of *config*.

        Raises AttributeError if *config* lacks any of those attributes.
        """
        field_to_attr = (
            ('block_size', 'blocksparse_block_size'),
            ('kernel_block_size', 'blocksparse_triton_kernel_block_size'),
            ('num_local_blocks', 'blocksparse_num_local_blocks'),
            ('vert_stride', 'blocksparse_vert_stride'),
            ('homo_head_pattern', 'blocksparse_homo_head_pattern'),
        )
        return cls(**{field: getattr(config, attr) for field, attr in field_to_attr})
| |
|
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
def dense_to_crow_col(x):
    '''Convert a 2D or 3D dense torch tensor into CSR (crow_indices, col_indices).

    A 3D input is treated as a batch of 2D matrices. Matrices in a batch can hold
    different numbers of non-zeros, so each matrix's col_indices are right-padded
    with -1 to a common length before stacking.

    TODO:
        1. improve efficiency, is it faster if done in CPU, or customize a cuda kernel for it?

    :param x: dense tensor of shape ``(rows, cols)`` or ``(batch, rows, cols)``.
    :return: ``(crows, cols)``; 1D tensors for a 2D input, otherwise 2D with one
        row per batch entry (cols padded with -1).
    '''
    PAD_VALUE = -1
    single_matrix = x.dim() == 2
    assert x.dim() in (2, 3)
    batch = x[None] if single_matrix else x

    csr_mats = [mat.to_sparse_csr() for mat in batch]
    crows = torch.vstack([mat.crow_indices() for mat in csr_mats])

    col_lists = [mat.col_indices() for mat in csr_mats]
    width = max(len(c) for c in col_lists)
    padded_cols = [torch.cat([c, PAD_VALUE + c.new_zeros(width - c.shape[0])]) for c in col_lists]
    cols = torch.vstack(padded_cols)

    if single_matrix:
        return crows[0], cols[0]
    return crows, cols
| |
|
| |
|
def crow_col_to_dense(crows, cols, dtype=torch.float16):
    '''Inverse of ``dense_to_crow_col``: rebuild a dense 0/1 mask from CSR indices.

    :param crows: crow_indices, 1D ``(rows + 1,)`` or 2D ``(batch, rows + 1)``.
    :param cols: matching col_indices (1D or 2D). Entries outside
        ``crows[..., 0]:crows[..., -1]`` (e.g. the -1 padding written by
        ``dense_to_crow_col``) are ignored.
    :param dtype: dtype of the returned mask.
    :return: dense mask of shape ``(rows, max_col + 1)`` or
        ``(batch, rows, max_col + 1)``, moved to the device of *crows*. The width is
        inferred from the largest column index present, so trailing all-zero
        columns are not represented.
    '''
    squeeze = crows.dim() == 1
    if squeeze:
        crows = crows[None]
        cols = cols[None]
    device = crows.device
    # Index tensors are small; build the mask on CPU and move it once at the end.
    crows, cols = crows.cpu(), cols.cpu()

    n_batch, n_rows = crows.shape[0], crows.shape[1] - 1
    n_cols = int(cols.max()) + 1
    x = torch.zeros((n_batch, n_rows, n_cols), dtype=dtype)

    row_range = torch.arange(n_rows)
    for b in range(n_batch):
        row_ptr = crows[b].long()
        # Expand the per-row non-zero counts into one row id per stored entry.
        # This replaces a Python loop over every row with a single scatter.
        row_ids = torch.repeat_interleave(row_range, row_ptr[1:] - row_ptr[:-1])
        col_ids = cols[b, int(row_ptr[0]):int(row_ptr[-1])].long()
        x[b, row_ids, col_ids] = 1

    if squeeze:
        x = x[0]
    return x.to(device)
| |
|
| |
|
def dense_to_ccol_row(x):
    '''CSC counterpart of ``dense_to_crow_col``.

    The CSC indices of ``x`` are the CSR indices of ``x`` transposed over its last
    two dimensions. The result is ``(ccol_indices, row_indices)``, with row indices
    padded by -1 the same way ``dense_to_crow_col`` pads column indices.
    '''
    transposed = x.transpose(-2, -1)
    return dense_to_crow_col(transposed)
| |
|
| |
|
def ccol_row_to_dense(ccol, rows, dtype=torch.float16):
    '''Inverse of ``dense_to_ccol_row``: rebuild a dense 0/1 mask from CSC indices.

    The CSC indices are decoded as the CSR indices of the transposed matrix, then
    transposed back over the last two dimensions. ``transpose(-2, -1)`` is used
    instead of a fixed 3D ``permute(0, 2, 1)``, so 1D (single matrix) index tensors
    work as well as 2D (batched) ones.
    '''
    return crow_col_to_dense(ccol, rows, dtype).transpose(-2, -1).contiguous()
| |
|
| |
|
| | def _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, return_dense=False): |
| | ''' |
| | :return: a tuple of 3: |
| | - tuple of crow_indices, col_indices representation of CSR format. |
| | - block dense mask |
| | - all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None |
| | ''' |
| | with torch.no_grad(): |
| | N_BLOCK = triton.cdiv(N_CTX, BLOCK) |
| | q_pos = torch.arange(N_BLOCK)[:, None] |
| | k_pos = torch.arange(N_BLOCK)[None] |
| | mask_vert_strided = (torch.arange(N_BLOCK) + 1) % vert_stride == 0 |
| | block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype) |
| | N_BLOCK_Q = triton.cdiv(q_len, BLOCK) |
| | block_mask_dense_output = block_mask_dense[-N_BLOCK_Q:].contiguous().to_sparse_csr() |
| | if return_dense: |
| | mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK))) |
| | causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:] |
| | mask_dense = mask_dense[-q_len:, :N_CTX] * causal_mask |
| | return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, mask_dense |
| | else: |
| | return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, None |
| |
|
| |
|
def _get_sparse_attn_mask(n_heads, q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, homo_head=True, return_dense=False):
    '''
    Build the per-head local + vertical-strided block-sparse layout.

    With ``homo_head`` every head shares the pattern from
    ``_get_sparse_attn_mask_homo_head`` (expanded over heads). Otherwise head ``h``
    attends key block ``j`` vertically when
    ``(j + h * head_sliding_step + 1) % vert_stride == 0``, with
    ``head_sliding_step = max(1, int(vert_stride / n_heads))``.

    :return: a tuple of 3:
        - tuple of crow_indices, col_indices representation of CSR format.
        - block dense mask
        - all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None
    '''
    if homo_head:
        with torch.no_grad():
            (crow, col), block_mask_dense, mask_dense = _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK, local_blocks, vert_stride, return_dense)
            crow = crow[None].expand(n_heads, crow.shape[0])
            col = col[None].expand(n_heads, col.shape[0])
            if return_dense:
                mask_dense = mask_dense[None].expand(n_heads, *mask_dense.shape)
            return (crow, col), block_mask_dense, mask_dense

    with torch.no_grad():
        # Ceil division (identical to triton.cdiv).
        N_BLOCK = (N_CTX + BLOCK - 1) // BLOCK
        q_pos = torch.arange(N_BLOCK)[None, :, None]
        k_pos = torch.arange(N_BLOCK)[None, None]
        head_sliding_step = max(1, int(vert_stride / n_heads))
        mask_vert_strided = [(torch.arange(N_BLOCK) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(n_heads)]
        mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
        block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype)
        N_BLOCK_Q = (q_len + BLOCK - 1) // BLOCK
        block_mask_dense_output = block_mask_dense[:, -N_BLOCK_Q:]
        layout = dense_to_crow_col(block_mask_dense_output)
        if not return_dense:
            return layout, block_mask_dense, None

        mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
        causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:]
        # Token rows of the kron'd mask run over N_BLOCK*BLOCK positions (padding
        # included); the queries are positions N_CTX-q_len .. N_CTX-1.
        mask_dense = mask_dense[..., N_CTX - q_len:N_CTX, :N_CTX] * causal_mask[None]
        return layout, block_mask_dense, mask_dense
| |
|
| |
|
def get_sparse_attn_mask(q, N_CTX, *args, **kwargs):
    '''Wrapper around ``_get_sparse_attn_mask`` that takes the head count
    (``q.size(1)``), query length (``q.size(2)``), dtype and device from *q*.'''
    n_heads, q_len = q.size(1), q.size(2)
    return _get_sparse_attn_mask(n_heads, q_len, N_CTX, q.dtype, q.device, *args, **kwargs)
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    layout_crow_ptr,
    layout_col_ptr,
    layout_crow_stride_h, layout_crow_stride_m,
    layout_col_stride_h, layout_col_stride_m,
    TMP, L, M,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qd,
    stride_kz, stride_kh, stride_kn, stride_kd,
    stride_vz, stride_vh, stride_vn, stride_vd,
    stride_oz, stride_oh, stride_om, stride_od,
    Z, H, N_CTX,
    PAST_LEN,
    Q_ROUNDED_LEN,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    EVEN_M_BLOCK: tl.constexpr,
    EVEN_N_BLOCK: tl.constexpr,
    INFERENCE: tl.constexpr,
    NUM_DBLOCKS: tl.constexpr,
):
    # Block-sparse causal attention forward.
    # One program handles BLOCK_M query rows (program_id(0)) of one (batch, head)
    # pair (program_id(1) = batch * H + head). It iterates only over the key blocks
    # listed in the CSR layout row for its query block.
    # Queries are the last Q_LEN = N_CTX - PAST_LEN positions of the key sequence.
    # Head dims wider than BLOCK_DMODEL are handled as a second half
    # (NUM_DBLOCKS >= 2: q2/acc2 at offset BLOCK_DMODEL).
    Q_LEN = N_CTX - PAST_LEN
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_h = off_hz % H
    off_z = off_hz // H
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
    # K is addressed as a (BLOCK_DMODEL, BLOCK_N) tile so tl.dot(q, k) needs no transpose.
    off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
    off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd

    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # NOTE(review): t_ptrs (into TMP) is computed but never used below.
    t_ptrs = TMP + off_hz * Q_ROUNDED_LEN + offs_m
    # Online-softmax state: running row max (m_i), running normaliser (l_i),
    # and the output accumulator kept normalised by l_i.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    if NUM_DBLOCKS >= 2:
        acc2 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # Load the query tile once; rows past Q_LEN are masked when the length is not a multiple of BLOCK_M.
    if EVEN_M_BLOCK:
        q = tl.load(q_ptrs)
        if NUM_DBLOCKS >= 2:
            q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
    else:
        q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN)
        if NUM_DBLOCKS >= 2:
            q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m[:, None] < Q_LEN)

    # CSR row for this query block: entries [start_l, end_l) of the col array.
    layout_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + start_m * layout_crow_stride_m
    start_l = tl.load(layout_ptr).to(tl.int32)
    end_l = tl.load(layout_ptr + layout_crow_stride_m).to(tl.int32)

    # Loop over only the non-zero key blocks of this query block.
    for col_idx_idx in range(start_l, end_l):
        col_idx = tl.load(layout_col_ptr + off_h * layout_col_stride_h + col_idx_idx * layout_col_stride_m).to(tl.int32)
        start_n = col_idx * BLOCK_N
        # NOTE(review): masked loads below pass no `other`, so out-of-range lanes are
        # not explicitly zeroed; they are excluded afterwards by the causal -inf mask.
        if EVEN_N_BLOCK:
            k = tl.load(k_ptrs + start_n * stride_kn)
        else:
            k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_n[None, :] + start_n < N_CTX)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)

        if NUM_DBLOCKS >= 2:
            if EVEN_N_BLOCK:
                k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd, mask=offs_n[None, :] + start_n < N_CTX)
            qk += tl.dot(q2, k)

        qk *= sm_scale
        # Causal mask: query row i sits at absolute position i + PAST_LEN.
        qk += tl.where(offs_m[:, None] + PAST_LEN >= (start_n + offs_n[None, :]), 0, float('-inf'))

        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)

        # Merge this block's (max, sum) into the running statistics.
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij

        # Normalise this block's probabilities by the new normaliser.
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]

        # Re-weight the previously accumulated output to the new normaliser.
        acc_scale = l_i / l_i_new * alpha

        acc = acc * acc_scale[:, None]
        if NUM_DBLOCKS >= 2:
            acc2 = acc2 * acc_scale[:, None]
        p = p.to(Q.dtype.element_ty)

        if EVEN_N_BLOCK:
            v = tl.load(v_ptrs + start_n * stride_vn)
        else:
            v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_n[:, None] + start_n < N_CTX)
        acc += tl.dot(p, v)

        if NUM_DBLOCKS >= 2:
            if EVEN_N_BLOCK:
                v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] + start_n < N_CTX)
            acc2 += tl.dot(p, v)

        l_i = l_i_new
        m_i = m_i_new

    # Save softmax statistics for the backward pass (skipped for inference).
    # NOTE(review): rows are strided by N_CTX here, while _forward allocates
    # Q_ROUNDED_LEN columns per (batch, head); the backward reads with the same
    # N_CTX stride, which is consistent when q_len == k_len.
    if not INFERENCE:
        l_ptrs = L + off_hz * N_CTX + offs_m
        m_ptrs = M + off_hz * N_CTX + offs_m
        if EVEN_M_BLOCK:
            tl.store(l_ptrs, l_i)
            tl.store(m_ptrs, m_i)
        else:
            tl.store(l_ptrs, l_i, mask=offs_m < Q_LEN)
            tl.store(m_ptrs, m_i, mask=offs_m < Q_LEN)

    # Write the (already normalised) output tile; rows past Q_LEN are masked.
    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN)
    if NUM_DBLOCKS >= 2:
        tl.store(out_ptrs + BLOCK_DMODEL * stride_od, acc2, mask=offs_m[:, None] < Q_LEN)
| |
|
| |
|
| | |
@triton.heuristics(
    {
        'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
    }
)
@triton.jit
def _bwd_preprocess(
    Out, DO, L,
    NewDO, Delta,
    N_CTX,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
    EVEN_M_BLOCK: tl.constexpr,
):
    # Backward pre-pass over flat rows.
    # Out/DO are addressed as a row-major (rows, D_HEAD) matrix (row * D_HEAD + d),
    # so they are assumed contiguous. This pass writes:
    #   NewDO = DO / L (per row)
    #   Delta = rowsum(Out * NewDO)
    # Both are consumed by _bwd_kernel.
    # off_m is a flat row index over the whole launch, and N_CTX is used purely as the
    # bound on that flat index when masking.
    # NOTE(review): for correctness N_CTX must cover all flattened rows, not a single
    # head's sequence length.
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_d = tl.arange(0, D_HEAD)

    if EVEN_M_BLOCK:
        o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
        do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
    else:
        o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
        do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
    # L (and Delta) are read/written unmasked; the caller must size them to cover every program's rows.
    denom = tl.load(L + off_m).to(tl.float32)

    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)

    if EVEN_M_BLOCK:
        tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do)
    else:
        tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do, mask=off_m[:, None] < N_CTX)
    tl.store(Delta + off_m, delta)
| |
|
| |
|
| | |
@triton.heuristics(
    {
        'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
        'EVEN_N_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_N'] == 0,
    }
)
@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale,
    layout_ccol_ptr,
    layout_row_ptr,
    layout_ccol_stride_h, layout_ccol_stride_m,
    layout_row_stride_h, layout_row_stride_m,
    Out, DO,
    DQ, DK, DV,
    L, M,
    D,
    stride_qz, stride_qh, stride_qm, stride_qd,
    stride_kz, stride_kh, stride_kn, stride_kd,
    stride_vz, stride_vh, stride_vn, stride_vd,
    stride_oz, stride_oh, stride_om, stride_od,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    EVEN_M_BLOCK: tl.constexpr,
    EVEN_N_BLOCK: tl.constexpr,
    NUM_DBLOCKS: tl.constexpr,
):
    # Block-sparse attention backward.
    # One program owns BLOCK_N keys (program_id(0)) of one (batch, head) pair
    # (program_id(1)). It walks the CSC column of its key block to visit only the
    # query blocks that attended it, and:
    #   - accumulates dK/dV in registers;
    #   - adds dQ contributions with atomic_add, since several key-block programs
    #     write the same dQ rows.
    # DO is expected to be the pre-scaled NewDO and D the Delta from _bwd_preprocess.
    # Out, L and num_block are accepted but not read here.
    # DO/DQ/DK/DV are all addressed with the Out strides (stride_o*).
    start_n = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H

    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    DO += off_z * stride_oz + off_h * stride_oh
    DQ += off_z * stride_oz + off_h * stride_oh
    DK += off_z * stride_oz + off_h * stride_oh
    DV += off_z * stride_oz + off_h * stride_oh

    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd)
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd)

    # Per-row Delta and softmax max (M) of this (batch, head), strided by N_CTX as written by the forward.
    D_ptrs = D + off_hz * N_CTX
    m_ptrs = M + off_hz * N_CTX

    dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)

    # K/V tiles of this program stay resident for the whole loop.
    if EVEN_N_BLOCK:
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
    else:
        k = tl.load(k_ptrs, mask=offs_n[:, None] < N_CTX)
        v = tl.load(v_ptrs, mask=offs_n[:, None] < N_CTX)

    if NUM_DBLOCKS >= 2:
        dv2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
        dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
        if EVEN_N_BLOCK:
            k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd)
            v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd)
        else:
            k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd, mask=offs_n[:, None] < N_CTX)
            v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] < N_CTX)

    # CSC column for this key block: entries [start_l, end_l) of the row array.
    layout_ptr = layout_ccol_ptr + off_h * layout_ccol_stride_h + start_n * layout_ccol_stride_m
    start_l = tl.load(layout_ptr).to(tl.int32)
    end_l = tl.load(layout_ptr + layout_ccol_stride_m).to(tl.int32)

    for row_idx_idx in range(start_l, end_l):
        row_idx = tl.load(layout_row_ptr + off_h * layout_row_stride_h + row_idx_idx * layout_row_stride_m).to(tl.int32)
        start_m = row_idx * BLOCK_M

        offs_m_curr = start_m + offs_m
        q_ptrs = Q + (offs_m_curr[:, None] * stride_qm + offs_d[None, :] * stride_qd)
        do_ptrs = DO + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)
        dq_ptrs = DQ + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)

        if EVEN_M_BLOCK:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX)

        # Recompute the attention logits for this (query block, key block) tile.
        qk = tl.dot(q, tl.trans(k))

        if NUM_DBLOCKS >= 2:
            if EVEN_M_BLOCK:
                q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
            else:
                q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m_curr[:, None] < N_CTX)
            qk += tl.dot(q2, tl.trans(k2))

        # Causal mask without a past offset: assumes queries and keys share positions (q_len == k_len).
        qk += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float('-inf'))

        # p = exp(s - m) is un-normalised; the 1/L factor was folded into DO by _bwd_preprocess.
        if EVEN_M_BLOCK:
            m = tl.load(m_ptrs + offs_m_curr)
        else:
            m = tl.load(m_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
        p = tl.exp(qk * sm_scale - m[:, None])

        if EVEN_M_BLOCK:
            do = tl.load(do_ptrs)
        else:
            do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX)

        if NUM_DBLOCKS >= 2:
            if EVEN_M_BLOCK:
                do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od)
            else:
                do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od, mask=offs_m_curr[:, None] < N_CTX)

        # dV += P^T @ dO
        dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)

        if NUM_DBLOCKS >= 2:
            dv2 += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do2)

        # dP = dO @ V^T - Delta (Delta broadcast per query row)
        if EVEN_M_BLOCK:
            Di = tl.load(D_ptrs + offs_m_curr)
        else:
            Di = tl.load(D_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
        dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
        dp += tl.dot(do, tl.trans(v))

        if NUM_DBLOCKS >= 2:
            dp += tl.dot(do2, tl.trans(v2))

        # Gradient w.r.t. the scaled logits.
        ds = p * dp * sm_scale

        # dK += dS^T @ Q
        dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
        if NUM_DBLOCKS >= 2:
            dk2 += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q2)

        # dQ += dS @ K; other key-block programs write the same rows, hence atomic_add
        # (DQ must be zero-initialised by the caller).
        dq = tl.dot(ds.to(Q.dtype.element_ty), k)
        if EVEN_M_BLOCK:
            tl.atomic_add(dq_ptrs, dq)
        else:
            tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < N_CTX)

        if NUM_DBLOCKS >= 2:
            dq2 = tl.dot(ds.to(Q.dtype.element_ty), k2)
            dq_ptrs2 = dq_ptrs + BLOCK_DMODEL * stride_od
            if EVEN_M_BLOCK:
                tl.atomic_add(dq_ptrs2, dq2)
            else:
                tl.atomic_add(dq_ptrs2, dq2, mask=offs_m_curr[:, None] < N_CTX)

    # This program exclusively owns its dK/dV rows, so plain stores suffice.
    dv_ptrs = DV + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
    dk_ptrs = DK + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
    if EVEN_N_BLOCK:
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)
    else:
        tl.store(dv_ptrs, dv, mask=offs_n[:, None] < N_CTX)
        tl.store(dk_ptrs, dk, mask=offs_n[:, None] < N_CTX)

    if NUM_DBLOCKS >= 2:
        dv_ptrs2 = dv_ptrs + BLOCK_DMODEL * stride_od
        dk_ptrs2 = dk_ptrs + BLOCK_DMODEL * stride_od
        if EVEN_N_BLOCK:
            tl.store(dv_ptrs2, dv2)
            tl.store(dk_ptrs2, dk2)
        else:
            tl.store(dv_ptrs2, dv2, mask=offs_n[:, None] < N_CTX)
            tl.store(dk_ptrs2, dk2, mask=offs_n[:, None] < N_CTX)
| |
|
| |
|
| |
|
def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, num_warps=None, num_stages=1, inference=None, out=None):
    '''
    Launch the block-sparse attention forward kernel and record on ``ctx`` what backward needs.

    :param q, k, v: [batch, n_heads, seq_len, model_dim]. len of q is allowed to be different than k/v;
        q is treated as the last ``q_len`` positions of the k/v sequence.
    :param layout_crow_indices, layout_col_indices: same as CSR.crow_indices and CSR.col_indices used to represent
        a sparse tensor. Each element represents a block, i.e., all elements in a block are attended, or none are.
        1D layouts are broadcast to every head.
    :param sm_scale: scale applied to q @ k^T before the softmax.
    :param BLOCK_M, BLOCK_N: query / key tile sizes of the kernel.
    :param num_warps: must be a power of 2; derived from the smallest tile dimension when None.
    :param num_stages: forwarded to the triton launch.
    :param inference: when true, softmax statistics are not saved (no backward possible);
        inferred from ``requires_grad`` of q/k/v when None.
    :param out: optional preallocated output tensor.
    :return: attention output, same shape as ``q``.
    '''
    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    assert k.shape[2] == v.shape[2]
    batch, n_heads, q_len = q.shape[0], q.shape[1], q.shape[2]
    k_len = k.shape[2]
    n_bh = batch * n_heads  # one grid row per (batch, head) pair

    o = out if out is not None else torch.empty_like(q).contiguous()
    n_m_tiles = triton.cdiv(q_len, BLOCK_M)
    grid = (n_m_tiles, n_bh)

    q_rounded_len = n_m_tiles * BLOCK_M
    stats_shape = (n_bh, q_rounded_len)
    tmp = torch.empty(stats_shape, device=q.device, dtype=torch.float32)

    if inference is None:
        inference = not (q.requires_grad or k.requires_grad or v.requires_grad)

    if inference:
        # The kernel skips writing L/M in inference mode, so the scratch buffer is aliased.
        L = m = tmp
    else:
        L = torch.empty(stats_shape, device=q.device, dtype=torch.float32)
        m = torch.empty(stats_shape, device=q.device, dtype=torch.float32)

    if layout_col_indices.dim() == 1:
        # Share one layout across all heads (expand creates a view, no copy).
        layout_crow_indices = layout_crow_indices[None].expand(n_heads, -1)
        layout_col_indices = layout_col_indices[None].expand(n_heads, -1)

    assert q.shape[-1] in [64, 128]
    # A head dim of 128 is processed by the kernel as two 64-wide halves (NUM_DBLOCKS == 2).
    BLOCK_DMODEL = 64

    if num_warps is not None:
        assert math.log2(num_warps) % 1 == 0, f'''"num_warps" should be power of 2, but got {num_warps}.'''
    else:
        smallest_tile = min(BLOCK_M, BLOCK_N, BLOCK_DMODEL)
        num_warps = max(1, 2 ** int(math.log2(smallest_tile / 16)))

    _fwd_kernel[grid](
        q, k, v, sm_scale,
        layout_crow_indices,
        layout_col_indices,
        layout_crow_indices.stride(0), layout_crow_indices.stride(1),
        layout_col_indices.stride(0), layout_col_indices.stride(1),
        tmp, L, m,
        o,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        batch, n_heads, k_len,
        k_len - q_len,  # PAST_LEN: number of key positions preceding the queries
        q_rounded_len,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        BLOCK_DMODEL=BLOCK_DMODEL,
        EVEN_M_BLOCK=q_len % BLOCK_M == 0,
        EVEN_N_BLOCK=k_len % BLOCK_N == 0,
        INFERENCE=inference,
        NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    if inference:
        L, m = None, None
    ctx.save_for_backward(q, k, v, o, L, m, layout_crow_indices, layout_col_indices)

    # Launch configuration reused by _backward.
    ctx.BLOCK_M, ctx.BLOCK_N, ctx.BLOCK_DMODEL = BLOCK_M, BLOCK_N, BLOCK_DMODEL
    ctx.grid = grid
    ctx.sm_scale = sm_scale
    ctx.num_warps, ctx.num_stages = num_warps, num_stages
    return o
| |
|
| |
|
def _backward(ctx, do, layout_ccol_indices, layout_row_indices, dq=None, dk=None, dv=None):
    '''
    Launch the backward kernels for a ``_forward`` call made with inference disabled.

    :param ctx: autograd context populated by ``_forward``.
    :param do: gradient of the attention output.
    :param layout_ccol_indices, layout_row_indices: CSC form of the forward block layout;
        1D layouts are broadcast to every head.
    :param dq, dk, dv: optional preallocated gradient buffers. ``dq`` must be zero-filled
        because ``_bwd_kernel`` accumulates into it with atomic adds.
    :return: ``(dq, dk, dv, None, None, None)``; ``dq`` defaults to float32.
    :raises ValueError: if the saved forward output is not contiguous.
    '''
    q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors

    if not do.is_contiguous():
        do = do.contiguous()

    # _bwd_preprocess indexes o/do as flat (rows, head_dim) matrices.
    if not o.is_contiguous():
        raise ValueError(f'--> output is not contiguous: {o.stride()=}. This is maybe caused by q/k/v not being contiguous.')

    if layout_ccol_indices.dim() == 1:
        layout_ccol_indices = layout_ccol_indices[None].expand(q.shape[1], -1)
        layout_row_indices = layout_row_indices[None].expand(q.shape[1], -1)

    dq = dq if dq is not None else torch.zeros_like(q, dtype=torch.float32)
    dk = dk if dk is not None else torch.empty_like(k)
    dv = dv if dv is not None else torch.empty_like(v)
    do_scaled = torch.empty_like(do)
    delta = torch.empty_like(l)

    # Every gradient buffer is addressed with o's strides inside _bwd_kernel.
    assert o.stride() == dq.stride() == dk.stride() == dv.stride() == do_scaled.stride()

    # _bwd_preprocess walks o/do/l as flat rows and uses its N_CTX argument only as the
    # bound on that flat row index. Passing one head's length (k.shape[2]) there masked
    # out every (batch, head) slice after the first whenever seq_len % BLOCK_M != 0,
    # leaving do_scaled/delta uninitialised. Bound (and size the grid) by the total row
    # count instead. This assumes q and k have the same length, which _bwd_kernel's
    # causal mask also assumes.
    n_rows = q.shape[0] * q.shape[1] * q.shape[2]
    _bwd_preprocess[(triton.cdiv(n_rows, ctx.BLOCK_M), )](
        o, do, l,
        do_scaled, delta,
        n_rows,
        BLOCK_M=ctx.BLOCK_M, D_HEAD=q.shape[-1],
    )

    # One program per key block per (batch, head).
    grid = (triton.cdiv(q.shape[2], ctx.BLOCK_N), ctx.grid[1])

    _bwd_kernel[grid](
        q, k, v, ctx.sm_scale,
        layout_ccol_indices,
        layout_row_indices,
        layout_ccol_indices.stride(0), layout_ccol_indices.stride(1),
        layout_row_indices.stride(0), layout_row_indices.stride(1),
        o, do_scaled,
        dq, dk, dv,
        l, m,
        delta,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        q.shape[0], q.shape[1], q.shape[2],
        ctx.grid[0],
        BLOCK_M=ctx.BLOCK_M,
        BLOCK_N=ctx.BLOCK_N,
        BLOCK_DMODEL=ctx.BLOCK_DMODEL,
        NUM_DBLOCKS=q.shape[-1] // ctx.BLOCK_DMODEL,
        num_warps=ctx.num_warps,
        num_stages=1,
    )
    return dq, dk, dv, None, None, None
| |
|
| |
|
class _sparse_attention(torch.autograd.Function):
    '''Autograd function for block-sparse attention with 128x128 kernel tiles.

    Backward rebuilds the CSC layout from the saved CSR layout on every call, via a
    dense round trip (crow_col_to_dense -> dense_to_ccol_row).
    '''

    @staticmethod
    def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
        tile = 128
        return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, tile, tile)

    @staticmethod
    def backward(ctx, do):
        # Only the CSR layout (the last two saved tensors) is needed here.
        crow_indices, col_indices = ctx.saved_tensors[-2:]
        dense_layout = crow_col_to_dense(crow_indices, col_indices)
        ccol_indices, row_indices = dense_to_ccol_row(dense_layout)
        return _backward(ctx, do, ccol_indices, row_indices)
| |
|
| |
|
| |
|
| | |
class _sparse_attention_inference(_sparse_attention):
    '''``_sparse_attention`` whose forward uses 1-row query tiles (BLOCK_M=1) with
    128-wide key tiles; backward is inherited unchanged.'''

    @staticmethod
    def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
        query_tile, key_tile = 1, 128
        return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, query_tile, key_tile)
| |
|
| |
|
| |
|
def sparse_attention_factory(BLOCK_M=128, BLOCK_N=128, **kwargs):
    '''Create a configured block-sparse attention op.

    :param BLOCK_M, BLOCK_N: query / key tile sizes for ``_forward``.
    :param kwargs: extra keyword arguments forwarded to ``_forward`` on every call.
    :return: the ``apply`` of a new ``_sparse_attention`` subclass; backward is inherited.
    '''
    class _sparse_attention_config(_sparse_attention):
        @staticmethod
        def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
            return _forward(
                ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale,
                BLOCK_M, BLOCK_N, **kwargs,
            )

    op = _sparse_attention_config.apply
    return op
| |
|
| |
|
@lru_cache(maxsize=8)
def get_local_strided_sparse_attention_op(
        n_heads: int,
        max_seq_len:int,
        sparse_block_size: int=128,
        local_blocks: int=4,
        vert_stride: int=4,
        homo_head: bool=False,
        dtype=torch.bfloat16,
        device='cuda',
        active_head_range=None,
        verbose=True,
        **kwargs):
    '''
    Build a block-sparse attention op with a "local + vertical stride" pattern.
    Results are memoised for up to 8 configurations via lru_cache, so all arguments,
    including kwargs values, must be hashable.

    :param n_heads: total number of attention heads (regardless of tensor/model parallel).
    :param max_seq_len: max sequence length; must be >= the length of any sequence passed to the op.
    :param sparse_block_size: sparse block size. Defaults to 128.
    :param local_blocks: number of nearest blocks to attend to. Defaults to 4, i.e.
        attention to the previous 4 x sparse_block_size tokens.
    :param vert_stride: defaults to 4. Every vert_stride-th key block is also attended by
        all later query blocks; with homo_head=False the stride phase is offset per head
        (see _get_sparse_attn_mask).
    :param homo_head: whether all heads share the same pattern.
    :param dtype, device: dtype/device of the generated block pattern.
    :param active_head_range: tuple of start & end head index, e.g. (8, 16). Defaults to all heads.
        Mainly for tensor/model parallelism where heads are split across GPUs.
        Ignored when homo_head is True.
    :param verbose: print the configuration when a new op is constructed.
    :param kwargs: forwarded to get_sparse_attn_op.
    '''
    if verbose:
        print(f'> new block_sparse_attn op constructed with config: '
              f'{n_heads=}, {max_seq_len=}, {sparse_block_size=}, {local_blocks=}, '
              f'{vert_stride=}, {homo_head=}, {active_head_range=}, {kwargs=}')

    _, block_sparse_pattern, _ = _get_sparse_attn_mask(
        n_heads, max_seq_len, max_seq_len, dtype, device,
        BLOCK=sparse_block_size, local_blocks=local_blocks,
        vert_stride=vert_stride, homo_head=homo_head,
        return_dense=False,
    )

    if homo_head or active_head_range is None:
        return get_sparse_attn_op(block_sparse_pattern, sparse_block_size, **kwargs)

    assert isinstance(active_head_range, tuple)
    assert len(active_head_range) == 2, '"active_head_range" should be a tuple of start/end index of the heads.'
    first_head, end_head = active_head_range
    return get_sparse_attn_op(block_sparse_pattern[first_head:end_head], sparse_block_size, **kwargs)
| |
|
| |
|
def get_sparse_attn_op(
        sparse_pattern: torch.tensor,
        sparse_block_size: int=128,
        kernel_block_size=128,
        qkv_format='q,k,v',
        **kwargs):
    '''
    Create a block-sparse op with a fixed layout. This avoids re-creating the CSR layout and converting it
    to CSC on every call, which is very inefficient (it uses Python loops on CPU; PyTorch 1.13 supports
    CSR->CSC, which may help).

    :param sparse_pattern: sparse pattern of the blocks. Should be `num_blocks(q) x num_blocks(k)` or `n_heads x num_blocks x num_blocks`.
        This tensor should have lower-triangular matrices on the last 2 dimensions for causal attention.
    :param sparse_block_size: sparse block size. Defaults to 128.
    :param kernel_block_size: the tile/block size used to launch a triton instance. Defaults to 128; pass None to
        use `sparse_block_size`. Must divide `sparse_block_size`, be a power of 2, and be at least 16.
    :param qkv_format: Choices=['q,k,v', 'q, kv', 'qkv'], i.e., separated q,k,v, or kv packed, or qkv packed. Currently, only 'q,k,v' is supported.

    :param kwargs: keyword arguments passed to `_forward`
    :return: a function ``(q, k, v, sm_scale) -> out`` (autograd-enabled) with the pattern and its CSR/CSC
        layouts attached as attributes.
    '''
    assert qkv_format == 'q,k,v'

    if kernel_block_size is None:
        kernel_block_size = sparse_block_size
    else:
        assert sparse_block_size % kernel_block_size == 0, f"The sparse block size must be a multiple of {kernel_block_size}."
        assert kernel_block_size >=16 and math.log2(kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {kernel_block_size} is given"

    # When the kernel tile is smaller than the sparse block, refine the block pattern
    # to kernel-tile granularity and re-apply the block-causal mask, so the diagonal
    # sparse blocks become lower-triangular at the finer granularity.
    if sparse_block_size // kernel_block_size > 1:
        _mul = sparse_block_size // kernel_block_size
        sparse_pattern = torch.kron(sparse_pattern, sparse_pattern.new_ones(_mul, _mul))
        num_sparse_blocks = sparse_pattern.size(-1)
        block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
        sparse_pattern *= block_causal_mask.type_as(sparse_pattern)

    BLOCK_N = kernel_block_size
    NUM_BLOCK = sparse_pattern.size(-1)
    # NOTE(review): MAX_SEQ_LEN is computed but not used below.
    MAX_SEQ_LEN = kernel_block_size * NUM_BLOCK

    # Precompute the full-size CSR (forward) and CSC (backward) layouts once.
    grand_layout_crow_indices, grand_layout_col_indices = dense_to_crow_col(sparse_pattern)
    grand_layout_ccol_indices, grand_layout_row_indices = dense_to_ccol_row(sparse_pattern)

    max_cache_size = 1 if kwargs.get('inference', False) else 8

    # CSC layout for a sequence of `block_len` kernel blocks: the cached full layout,
    # or a freshly converted top-left slice of the pattern.
    @lru_cache(maxsize=max_cache_size)
    def get_backward_layout_by_block_len(block_len):
        assert block_len <= NUM_BLOCK
        if block_len == NUM_BLOCK:
            return (grand_layout_ccol_indices, grand_layout_row_indices)
        return dense_to_ccol_row(sparse_pattern[..., :block_len, :block_len])

    class _q_k_v_sparse_attention(torch.autograd.Function):
        @staticmethod
        def forward(ctx, q, k, v, sm_scale):
            MIN_BLOCK_SIZE = 16
            assert BLOCK_N >= MIN_BLOCK_SIZE
            # Short query sequences (q_len <= 16) use 16-row query tiles.
            BLOCK_M = 16 if q.shape[2] <= 16 else BLOCK_N

            # Queries are the last q_len positions of k, so their layout rows start
            # at block K_BLOCKS - ceil(q_len / BLOCK_N).
            K_BLOCKS = triton.cdiv(k.shape[2], kernel_block_size)
            Q_START_BLOCKS = K_BLOCKS - triton.cdiv(q.shape[2], BLOCK_N)

            # Only the crow slice is taken; its entries are absolute offsets into
            # the full col array, so the col indices can be passed through unchanged.
            layout_crow_indices = grand_layout_crow_indices[..., Q_START_BLOCKS:K_BLOCKS+1]
            layout_col_indices = grand_layout_col_indices

            return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N,
                            **kwargs
                            )
        @staticmethod
        def backward(ctx, do):
            q, k = ctx.saved_tensors[:2]
            assert q.shape[2] == k.shape[2], '> currently backward can only be done if q, k have same length. Contact @EricLin if you need it.'
            block_len = triton.cdiv(do.shape[2], kernel_block_size)
            backward_layout = get_backward_layout_by_block_len(block_len)
            # Gradients for (q, k, v, sm_scale) only.
            return _backward(ctx, do, *backward_layout)[:4]

    def _q_k_v_sparse_attention_fn(*args):
        return _q_k_v_sparse_attention.apply(*args)

    # Expose the pattern and precomputed layouts for inspection.
    _q_k_v_sparse_attention_fn.sparse_pattern = sparse_pattern
    _q_k_v_sparse_attention_fn.grand_layout_crow_indices = grand_layout_crow_indices
    _q_k_v_sparse_attention_fn.grand_layout_col_indices = grand_layout_col_indices
    _q_k_v_sparse_attention_fn.grand_layout_ccol_indices = grand_layout_ccol_indices
    _q_k_v_sparse_attention_fn.grand_layout_row_indices = grand_layout_row_indices

    return _q_k_v_sparse_attention_fn
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
def blocksparse_flash_attn_padded_fwd(
    q, k, v,
    sm_scale,
    sparse_layout,
    *,
    left_paddings=None,
    seqlens=None,
    block_size=64,
    max_seqlen=None
):
    '''
    Block-sparse attention forward pass for padded, batched inputs.

    q, k, v: (batch, tokens, n_heads/n_kv_heads, head_size)
    sm_scale: scale multiplied into q @ k^T inside the kernel.
    sparse_layout: tuple (layout_crow_indices, layout_col_indices), a CSR-style block
        layout. Presumably it is at `block_size` granularity with a leading head
        dimension; TODO confirm against the layout builder.
    left_paddings: (batch, ), number of left paddings for each sample.
    seqlens: can be used to specify right padding. No need to specify if left_paddings is used.
    block_size: kernel block size for both the query (BLOCK_M) and key (BLOCK_N) axes.
    max_seqlen: optional upper bound on the k length; only used for validation.

    The q length must be 1 (decoding) or equal to the k length (prefilling).
    Returns a tensor shaped like q. It is allocated with zeros, so query rows outside
    each sample's [start, end) span (e.g. left padding) stay zero.
    '''
    # Check ranks before unpacking so malformed input fails on the intended assertion.
    assert q.dim() == k.dim() == v.dim() == 4
    batches, q_len, n_heads, head_size = q.shape
    k_len = k.size(1)

    assert q.size(2) % k.size(2) == 0
    assert q.size(0) == k.size(0) and q.size(3) == k.size(3)
    assert k.shape == v.shape
    assert q_len == 1 or q_len == k_len, \
        'length of q should either be 1 (decoding) or same as k (prefilling).'

    # Number of query heads that share one kv head.
    q_k_ratio = q.size(2) // k.size(2)

    if max_seqlen:
        assert k.size(1) <= max_seqlen, f'k has seqlen {k.size(1)} while max sequence length is set to {max_seqlen}.'

    out = q.new_zeros(q.shape)

    layout_crow_indices, layout_col_indices = sparse_layout
    # The kernel works on a power-of-two head dimension and masks the tail when needed.
    block_d = triton.next_power_of_2(head_size)

    if left_paddings is not None:
        assert left_paddings.shape == (batches,)
        k_batch_starts = left_paddings.to(q.device, dtype=torch.int32).contiguous()
    else:
        k_batch_starts = torch.zeros((batches,), dtype=torch.int32, device=q.device)

    if seqlens is not None:
        k_batch_ends = k_batch_starts + seqlens.type_as(k_batch_starts)
        assert k_batch_ends.max() <= k_len, 'seqlens (+left_paddings if any) exceeds seqlen.'
    else:
        k_batch_ends = torch.zeros_like(k_batch_starts) + k_len

    if q_len == 1:
        # Decoding: the single query token sits at row 0 of every sample.
        q_batch_starts = torch.zeros_like(k_batch_starts)
        q_batch_ends = q_batch_starts + 1
    else:
        # Prefilling: queries span the same valid range as the keys.
        q_batch_starts = k_batch_starts
        q_batch_ends = k_batch_ends

    # One program per (query block, head). Enumerate every sample's query blocks on the host.
    q_lens = (q_batch_ends - q_batch_starts).cpu()
    n_blocks = (q_lens + block_size - 1) // block_size

    q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
                               dtype=q_batch_starts.dtype,
                               device=q_batch_starts.device)
    q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
                                dtype=q_batch_starts.dtype,
                                device=q_batch_starts.device)

    grid = (len(q_start_sids), n_heads)

    _fwd_kernel_batch_inference[grid](
        q, k, v, out,
        sm_scale,
        q_batch_starts,
        q_batch_ends,
        k_batch_starts,
        k_batch_ends,
        q_batch_ids,
        q_start_sids,

        *q.stride(),
        *k.stride(),
        *v.stride(),
        *out.stride(),

        layout_crow_indices,
        layout_col_indices,
        *layout_crow_indices.stride(),
        *layout_col_indices.stride(),

        q_k_ratio,
        HAS_BATCH_DIM = True,
        D_HEAD = head_size,
        BLOCK_M = block_size,
        BLOCK_N = block_size,
        BLOCK_D = block_d,
        BLOCK_M_LOADING = 16 if q_len == 1 else block_size,
        EVEN_D = block_d == head_size,
        num_warps = 1 if q_len == 1 else 4,
        num_stages = 3
    )

    return out
| |
|
| |
|
def blocksparse_flash_attn_varlen_fwd(
    q, k, v,
    cu_seqlens_k,
    cu_seqlens_q,
    sm_scale,
    sparse_layout,
    *,
    block_size=64,
    max_seqlen = None
):
    '''
    Block-sparse attention forward pass for packed variable-length inputs.

    q: (total_q_tokens, n_heads, head_size); k, v: (total_k_tokens, n_kv_heads, head_size).
    cu_seqlens_k / cu_seqlens_q: 1-D cumulative sequence offsets with batch + 1 entries.
        cu_seqlens_q may be None when the batch is purely decoding (one query token
        per sequence) or purely prefilling (q packed exactly like k).
    sm_scale: scale multiplied into q @ k^T inside the kernel.
    sparse_layout: tuple (layout_crow_indices, layout_col_indices), a CSR-style block
        layout. Presumably it is at `block_size` granularity; TODO confirm.
    block_size: kernel block size for both the query and key axes.
    max_seqlen: optional upper bound on every k length; only used for validation.

    Returns a tensor shaped like q. It is allocated with new_empty, so rows not
    covered by cu_seqlens_q are left uninitialized.
    '''
    _, num_q_heads, head_dim = q.shape
    num_seqs = cu_seqlens_k.size(0) - 1

    assert q.dim() == k.dim() == v.dim() == 3
    assert q.size(1) % k.size(1) == 0
    assert q.size(2) == k.size(2)
    assert k.shape == v.shape
    assert cu_seqlens_k.dim() == 1

    heads_per_kv = q.size(1) // k.size(1)

    # Infer the query offsets when the batch is entirely decoding or entirely prefilling.
    if cu_seqlens_q is not None:
        assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
    elif q.size(0) == num_seqs:
        cu_seqlens_q = torch.arange(0, num_seqs + 1,
                                    dtype=cu_seqlens_k.dtype,
                                    device=cu_seqlens_k.device)
    elif q.size(0) == k.size(0):
        cu_seqlens_q = cu_seqlens_k
    else:
        raise ValueError('cu_seqlens_q must be specified if it is mix of prefilling and decoding.')

    q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
    k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()

    is_decode = q_lens == 1
    assert torch.logical_or(is_decode, k_lens == q_lens).all(), \
        'length of q should either be 1 (decoding) or same as k (prefilling).'

    if max_seqlen:
        assert k_lens.max() <= max_seqlen

    # Each (sequence, query block) pair becomes one program along grid axis 0.
    n_blocks = (q_lens + block_size - 1) // block_size
    batch_id_list = []
    start_sid_list = []
    for seq_idx, num_blocks in enumerate(n_blocks):
        for blk in range(num_blocks):
            batch_id_list.append(seq_idx)
            start_sid_list.append(blk * block_size)

    q_batch_ids = torch.tensor(batch_id_list,
                               dtype=cu_seqlens_q.dtype,
                               device=cu_seqlens_q.device)
    q_start_sids = torch.tensor(start_sid_list,
                                dtype=cu_seqlens_q.dtype,
                                device=cu_seqlens_q.device)

    out = q.new_empty(q.shape)
    cu_seqlens_q = cu_seqlens_q.contiguous()
    cu_seqlens_k = cu_seqlens_k.contiguous()

    layout_crow_indices, layout_col_indices = sparse_layout
    padded_head_dim = triton.next_power_of_2(head_dim)

    decoding_only = is_decode.all()

    launch_grid = (len(q_start_sids), num_q_heads)

    # There is no batch dimension, so a 0 batch stride is passed ahead of each tensor's strides.
    _fwd_kernel_batch_inference[launch_grid](
        q, k, v, out,
        sm_scale,
        cu_seqlens_q[:-1],
        cu_seqlens_q[1:],
        cu_seqlens_k[:-1],
        cu_seqlens_k[1:],
        q_batch_ids,
        q_start_sids,

        0, *q.stride(),
        0, *k.stride(),
        0, *v.stride(),
        0, *out.stride(),

        layout_crow_indices,
        layout_col_indices,
        *layout_crow_indices.stride(),
        *layout_col_indices.stride(),

        heads_per_kv,
        HAS_BATCH_DIM = False,
        D_HEAD = head_dim,
        BLOCK_M = block_size,
        BLOCK_N = block_size,
        BLOCK_D = padded_head_dim,
        BLOCK_M_LOADING = 16 if decoding_only else block_size,
        EVEN_D = padded_head_dim == head_dim,
        num_warps = 1 if decoding_only else 4,
        num_stages = 3
    )

    return out
| |
|
| |
|
# Process a single key/value block for the query tile `q` (one online-softmax step).
#
# The key-block id is read from the CSR column array at [off_h, k_block_col_idx].
# The kernel computes sm_scale * q @ k for that block and folds the result into:
#   m_i: running row max,
#   l_i: running softmax normalizer,
#   acc: output accumulator, kept already normalized by l_i (see p_scale / acc_scale).
# Returns the updated (acc, l_i, m_i).
#
# LAST_K_BLOCK: mask keys at positions >= k_seqlen and apply the causal mask.
# M_LT_N: also apply the causal mask (the caller's heuristic sets it when BLOCK_M < BLOCK_N).
# EVEN_D: the head dim fills BLOCK_D exactly, so no head-dim masking is needed.
# NOTE(review): the masked loads below pass no `other=`, so masked lanes hold unspecified
# values. Masked key columns become -inf through the causal mask, so p is 0 there. If a
# masked v lane ever held NaN/inf, 0 * NaN would still poison acc. Consider other=0 — confirm.
@triton.jit
def _fwd_kernel_inner(
    acc, l_i, m_i,
    q, Q,
    k_block_col_idx,
    layout_col_ptr,
    layout_col_stride_h, layout_col_stride_m,
    k_ptrs,
    v_ptrs,
    off_h, offs_m, offs_n, offs_d,
    stride_kt, stride_vt,
    sm_scale,
    k_seqlen,
    past_len,
    LAST_K_BLOCK: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    BLOCK_N: tl.constexpr,
    D_HEAD: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr
):
    # Translate the CSR column entry into the key block id and its first key position.
    k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + k_block_col_idx * layout_col_stride_m).to(tl.int32)
    start_n = k_block_id * BLOCK_N
    # -- load k: the tile is indexed (offs_d rows, offs_n cols), i.e. k transposed for q @ k --
    if LAST_K_BLOCK:
        if EVEN_D:
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=offs_n[None, :] + start_n < k_seqlen)
        else:
            # also mask the padded head-dim rows beyond D_HEAD
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD))
    else:
        if EVEN_D:
            k = tl.load(k_ptrs + start_n * stride_kt)
        else:
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=offs_d[:, None] < D_HEAD)

    qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
    qk += tl.dot(q, k)

    qk *= sm_scale

    # Causal mask: query row offs_m (absolute position offs_m + past_len) may only see
    # keys at positions <= its own. Only needed for the last block or when BLOCK_M < BLOCK_N.
    if LAST_K_BLOCK | M_LT_N:
        qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf'))

    # -- block-local softmax statistics --
    m_ij = tl.max(qk, 1)
    p = tl.exp(qk - m_ij[:, None])

    l_ij = tl.sum(p, 1)
    # -- merge with the running max / normalizer --
    m_i_new = tl.maximum(m_i, m_ij)
    alpha = tl.exp(m_i - m_i_new)
    beta = tl.exp(m_ij - m_i_new)
    l_i_new = alpha * l_i + beta * l_ij
    # -- rescale so that acc stays normalized by the new l_i --
    # scale p
    p_scale = beta / l_i_new
    p = p * p_scale[:, None]
    # scale acc
    acc_scale = l_i / l_i_new * alpha
    acc = acc * acc_scale[:, None]

    # cast probabilities to the input element type for the second matmul
    p = p.to(Q.dtype.element_ty)
    # -- load v: the tile is indexed (offs_n rows, offs_d cols) --
    if LAST_K_BLOCK:
        if EVEN_D:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=offs_n[:, None] + start_n < k_seqlen)
        else:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD))
    else:
        if EVEN_D:
            v = tl.load(v_ptrs + start_n * stride_vt)
        else:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=offs_d[None, :] < D_HEAD)

    acc += tl.dot(p, v)
    # carry the updated statistics to the next block
    l_i = l_i_new
    m_i = m_i_new
    return acc, l_i, m_i
| |
|
| |
|
# M_LT_N is derived from the launch arguments: true when the query block is narrower
# than the key block, which requires the causal mask on every key block.
@triton.heuristics(
    {
        'M_LT_N': lambda kwargs: kwargs['BLOCK_M'] < kwargs['BLOCK_N'],
    }
)
@triton.jit
def _fwd_kernel_batch_inference(
    Q, K, V, Out,

    sm_scale,
    q_batch_starts,
    q_batch_ends,
    k_batch_starts,
    k_batch_ends,
    q_batch_ids,
    q_start_sids,

    stride_qb, stride_qt, stride_qh, stride_qd,
    stride_kb, stride_kt, stride_kh, stride_kd,
    stride_vb, stride_vt, stride_vh, stride_vd,
    stride_ob, stride_ot, stride_oh, stride_od,

    layout_crow_ptr,
    layout_col_ptr,
    layout_crow_stride_h, layout_crow_stride_m,
    layout_col_stride_h, layout_col_stride_m,

    q_k_ratio,

    HAS_BATCH_DIM: tl.constexpr,
    D_HEAD: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr
):
    '''
    Block-sparse attention forward kernel for batched (padded) or packed (varlen) inference.

    Grid: (number of query blocks over all sequences, n_heads). Program (off_zm, off_h)
    handles one query block for query head off_h:
        sequence index  = q_batch_ids[off_zm]
        first query row = q_start_sids[off_zm] (relative to that sequence's start)

    NOTATION:
        pid: position id
        sid: storage id
        sbid: storage block id
        pbid: position block id
        offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)

    Arguments:
        q_batch_starts / q_batch_ends: per-sequence [start, end) row offsets into Q/Out
            (within the sample's batch row when HAS_BATCH_DIM).
        k_batch_starts / k_batch_ends: the same for K and V.
        stride_*b: batch strides; only applied when HAS_BATCH_DIM.
        layout_crow_ptr / layout_col_ptr: CSR block layout. The crow row for the query
            block at absolute block position q_pbid gives the [start, end) range of
            column entries, and each column entry is a key block id.
        q_k_ratio: number of query heads sharing one kv head (kv head = off_h // q_k_ratio).
        BLOCK_M_LOADING: number of query rows actually loaded (the callers use 16 for decoding).

    q and blocks in KV need to be contiguous.

    TODO:
        Optimize grouped-attn

        CUDA graph support issue
        1. grid is dynamic: vllm sets up multiple cuda graphs in the decoding phase, with different max token sizes (16, 32, ...).
           Since prompt and decoding phases are mixed here, it can be more complex;
           different cuda graphs would be needed for different (off_zm, off_z).

        # indeed, q_batch_ids can be padded to the maximum number of grid[0], i.e., assume all decoding
        therefore, cu_seqlens_q, kv_seq_lens
    '''
    off_zm = tl.program_id(0)
    off_h = tl.program_id(1)

    # grouped attention: several query heads read the same kv head
    off_h_for_kv = off_h // q_k_ratio
    off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)
    q_start_sid = tl.load(q_start_sids + off_zm)
    start_m = q_start_sid // BLOCK_M

    if HAS_BATCH_DIM:
        Q += off_z * stride_qb
        K += off_z * stride_kb
        V += off_z * stride_vb
        Out += off_z * stride_ob

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
    q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start

    k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
    k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start

    # queries are the last q_seqlen positions of the key sequence
    past_len = k_seqlen - q_seqlen

    # move base pointers to this sequence's first row and this program's head
    Q += q_cu_start * stride_qt + off_h * stride_qh
    K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
    V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
    Out += q_cu_start * stride_ot + off_h * stride_oh

    # layout row of this query block, in absolute (past-shifted) position units
    q_pbid = (past_len + q_start_sid) // BLOCK_M

    if EVEN_D:
        q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
                    mask=offs_m[:, None] < q_seqlen)
    else:
        q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
                    mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
                    other=0)

    sparse_crow_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + q_pbid * layout_crow_stride_m

    # [k_block_start, k_block_end) indexes the column entries (key blocks) for this query block
    k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
    k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)

    # online-softmax state: running max, running normalizer, normalized output
    m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)

    # k is addressed transposed (d rows, n cols) so that q @ k needs no explicit transpose
    k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
    v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd

    # every key block except the last runs without sequence-end masking (LAST_K_BLOCK=False)
    for k_block_col_idx in range(k_block_start, k_block_end - 1):
        acc, l_i, m_i = _fwd_kernel_inner(
            acc, l_i, m_i,
            q, Q,
            k_block_col_idx,
            layout_col_ptr,
            layout_col_stride_h, layout_col_stride_m,
            k_ptrs,
            v_ptrs,
            off_h, offs_m, offs_n, offs_d,
            stride_kt, stride_vt,
            sm_scale,
            k_seqlen,
            past_len,
            False,
            BLOCK_M_LOADING,
            BLOCK_N,
            D_HEAD,
            EVEN_D,
            M_LT_N
        )

    # last key block: sequence-end and causal masking enabled (LAST_K_BLOCK=True).
    # NOTE(review): this assumes every layout row has at least one entry and that the last
    # entry is the block holding the diagonal — confirm against the layout builder.
    acc, l_i, m_i = _fwd_kernel_inner(
        acc, l_i, m_i,
        q, Q,
        k_block_end - 1,
        layout_col_ptr,
        layout_col_stride_h, layout_col_stride_m,
        k_ptrs,
        v_ptrs,
        off_h, offs_m, offs_n, offs_d,
        stride_kt, stride_vt,
        sm_scale,
        k_seqlen,
        past_len,
        True,
        BLOCK_M_LOADING,
        BLOCK_N,
        D_HEAD,
        EVEN_D,
        M_LT_N
    )

    # write only the valid query rows (and head-dim columns when BLOCK_D > D_HEAD)
    if EVEN_D:
        tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
                 mask=offs_m[:, None] < q_seqlen)
    else:
        tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
                 mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD))
| |
|
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| |
|
def torch_attention(q, k, v, attn_mask=None, sm_scale=None, block_attn_mask=None, block_size=128, do=None):
    '''
    Reference attention implemented with plain torch ops (used to check the triton kernels).

    q, k, v: shape=(batch, n_heads, seq, dim)
    attn_mask: optional dense mask broadcastable to (batch, n_heads, q_seq, k_seq);
        scores are kept only where the mask equals 1.
    sm_scale: multiplier applied to q @ k^T; defaults to 1 / sqrt(dim).
    block_attn_mask: optional block-level layout (..., n_blocks, n_blocks), e.g.
        (n_blocks, n_blocks) or (n_heads, n_blocks, n_blocks). When given, attention is
        computed one query block at a time, combined with a token-level causal mask,
        and `attn_mask` must be None. Query position i is assumed to align with key
        position i (q and k of the same length).
    block_size: number of tokens per block of `block_attn_mask`.
    do: optional upstream gradient. When given (dense path only), dv is computed and
        printed for debugging.

    Returns the attention output, shaped like q.
    '''
    if sm_scale is None:
        # standard softmax scaling; the scores are multiplied by sm_scale below
        sm_scale = 1.0 / math.sqrt(float(q.size(-1)))

    if block_attn_mask is not None:
        assert attn_mask is None
        q_len = q.size(2)
        outs = []
        for s in range(0, q_len, block_size):
            e = min(s + block_size, q_len)
            block_idx = s // block_size
            attn = torch.einsum('bhmd,bhnd->bhmn', q[:, :, s:e], k[:, :, :e]).float() * sm_scale
            # Expand this query block's layout row to token granularity: (..., 1, e).
            # repeat_interleave keeps any leading (e.g. head) dims, unlike kron.
            block_row = block_attn_mask[..., block_idx, :block_idx + 1]
            allowed = (block_row == 1).repeat_interleave(block_size, dim=-1)[..., None, :e]
            # Token-level causal constraint: query s + i may attend keys 0 .. s + i,
            # including its own position (the diagonal must stay visible).
            q_pos = torch.arange(s, e, device=attn.device)[:, None]
            k_pos = torch.arange(e, device=attn.device)[None, :]
            allowed = allowed.to(attn.device) & (k_pos <= q_pos)
            attn = attn.masked_fill(~allowed, float('-inf'))
            attn = attn.softmax(-1)
            outs.append(torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v[:, :, :e]))
        torch_output = torch.cat(outs, dim=2)
    else:
        attn = torch.einsum('bhmd,bhnd->bhmn', q, k).float() * sm_scale
        if attn_mask is not None:
            attn = attn.masked_fill((1 - attn_mask).bool(), float('-inf'))

        attn = attn.softmax(-1)
        if do is not None:
            dv = torch.einsum('bhqk,bhqd->bhkd', attn.type_as(do), do)
            print(f'> torch_attn computed dv: {dv=}')
        torch_output = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v)
    return torch_output
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| |
|
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 8, 2048, 128), (1, 4, 4096, 64)])
def test_op(Z, H, N_CTX, D_HEAD, Q_LEN=None, dtype=torch.bfloat16, homo_head=True, kernel_block_size=None, sparse_block_size=128, backward=True,
            sparse_attention_fn=None, local_blocks=4, vert_stride=4, sm_scale=None, max_length=None):
    '''
    Compare the block-sparse attention op against the torch reference on CUDA.

    Builds random q of shape (Z, H, Q_LEN, D_HEAD) and k, v of shape (Z, H, N_CTX, D_HEAD).
    It runs `torch_attention` with the dense form of the sparse mask and the sparse op with
    the same sm_scale, then asserts that the outputs agree. If `backward`, dq, dk and dv
    must agree as well.

    sm_scale: softmax scale used by both implementations; defaults to 1 / sqrt(D_HEAD).
    sparse_attention_fn: optional prebuilt op. Otherwise one is built from sparse_block_size,
        local_blocks, vert_stride, homo_head and kernel_block_size.
    max_length: unused; kept for call compatibility.
    '''
    Q_LEN = Q_LEN or N_CTX
    torch.manual_seed(20)
    q = torch.empty((Z, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)

    # Use the caller-supplied scale; only fall back to the standard 1/sqrt(D_HEAD) when
    # none is given (a hard-coded override here would make the parameter dead).
    if sm_scale is None:
        sm_scale = 1. / math.sqrt(D_HEAD)

    if backward:
        q.requires_grad_(), k.requires_grad_(), v.requires_grad_()

    dout = torch.randn_like(q).contiguous()

    _, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=sparse_block_size,
                                            local_blocks=local_blocks, vert_stride=vert_stride, homo_head=homo_head, return_dense=True)

    if sparse_attention_fn is None:
        sparse_attention_fn = get_local_strided_sparse_attention_op(H, N_CTX,
                                                                    sparse_block_size=sparse_block_size,
                                                                    local_blocks=local_blocks,
                                                                    vert_stride=vert_stride,
                                                                    homo_head=homo_head,
                                                                    device=q.device,
                                                                    dtype=q.dtype,
                                                                    kernel_block_size=kernel_block_size)
    ref_out = torch_attention(q, k, v, mask_dense, sm_scale)

    if backward:
        ref_out.backward(dout)
        # take the reference grads and clear them before the sparse op's backward
        ref_dv, v.grad = v.grad.clone(), None
        ref_dk, k.grad = k.grad.clone(), None
        ref_dq, q.grad = q.grad.clone(), None

    tri_out = sparse_attention_fn(q, k, v, sm_scale)

    assert torch.allclose(ref_out.cpu(), tri_out.cpu(), atol=1e-2, rtol=0), f'>> {ref_out[0, 0, :, 0].tolist()=}\n\n{tri_out[0, 0, :, 0].tolist()=}'

    if backward:
        tri_out.backward(dout)
        tri_dv, v.grad = v.grad.clone(), None
        tri_dk, k.grad = k.grad.clone(), None
        tri_dq, q.grad = q.grad.clone(), None

        assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=1e-2)
        assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
        assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)

    print(f'> test passed: {Z=}, {H=}, {N_CTX=}, {D_HEAD=}, {Q_LEN=}, {dtype=}, {homo_head=}, {sparse_block_size=}')
| |
|
| | |
| |
|
if __name__ == '__main__':
    # Ad-hoc driver: a correctness check followed by benchmarks (needs CUDA + triton).

    GPU_TYPE = os.popen('nvidia-smi --query-gpu=name --format=csv | tail -n 1').read().strip()

    support_backward = True

    # The dense triton provider ('triton' -> triton_attention) is only used behind this flag.
    HAS_DENSE_TRITON_FLASH = False

    # flash_attn is optional. Failure to load it (ImportError, or a broken CUDA extension
    # raising another Exception) only disables the 'flash' provider. Catching Exception
    # rather than BaseException keeps Ctrl-C / SystemExit working.
    try:
        from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_func
        HAS_FLASH = True
    except Exception:
        HAS_FLASH = False
        print('> cannot import flash_attn')

    # ---- training-style benchmark: fwd (and bwd) time vs. sequence length ----
    BATCH, N_HEADS, N_CTX, D_HEAD = 4, 32, 4096, 128

    BLOCK_SIZE = 64
    LOCAl_BLOCKS = 8
    VERT_STRIDE = 1
    HOMO_HEAD = False
    sparse_type = 'home' if HOMO_HEAD else 'hetero'
    dtype = torch.bfloat16

    modes = ['fwd', 'bwd'] if support_backward else ['fwd']

    configs = [triton.testing.Benchmark(
        x_names=['SEQ_LEN'],
        x_vals=[2**i for i in range(8, 16)],
        line_arg='provider',
        line_vals=(['triton'] if HAS_DENSE_TRITON_FLASH else []) + (['flash'] if HAS_FLASH else []) + ['triton_sparse'],
        line_names=(['Triton-Dense'] if HAS_DENSE_TRITON_FLASH else []) + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse'],
        styles=[('red', '-'), ('blue', '-'), ('green', '-')],
        ylabel='ms',
        plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}-{dtype}-{mode}',
        args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode}
    ) for mode in modes]

    @triton.testing.perf_report(configs)
    def bench_flash_attention(BATCH, H, SEQ_LEN, D_HEAD, mode, provider, dtype=torch.bfloat16, device='cuda', sparse_attention_fn=None):
        '''Time one provider's forward (or backward) pass at SEQ_LEN; returns milliseconds.'''
        assert mode in ['fwd', 'bwd']
        warmup = 25
        rep = 100
        N_CTX = SEQ_LEN
        if provider == 'triton':
            q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
            sm_scale = 1.3
            fn = lambda: triton_attention(q, k, v, sm_scale)
            if mode == 'bwd':
                o = fn()
                do = torch.randn_like(o)
                fn = lambda: o.backward(do, retain_graph=True)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms
        if provider == 'triton_sparse':
            q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
            sm_scale = 1.3

            if sparse_attention_fn is None:
                sparse_attention_fn = get_local_strided_sparse_attention_op(H, SEQ_LEN,
                                                                            local_blocks=LOCAl_BLOCKS,
                                                                            vert_stride=VERT_STRIDE,
                                                                            homo_head=HOMO_HEAD,
                                                                            sparse_block_size=BLOCK_SIZE,
                                                                            kernel_block_size=BLOCK_SIZE,
                                                                            device=q.device)

            fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
            if mode == 'bwd':
                o = fn()
                do = torch.randn_like(o)
                fn = lambda: o.backward(do, retain_graph=True)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms
        if provider == 'flash':
            lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
            cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
            cu_seqlens[1:] = lengths.cumsum(0)
            qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
            fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
            if mode == 'bwd':
                o = fn()
                do = torch.randn_like(o)
                fn = lambda: o.backward(do, retain_graph=True)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms

    # ---- inference benchmark: single-token decoding vs. past length ----
    BATCH, N_HEADS, N_CTX, D_HEAD, Q_LEN = 4, 32, 4096, 128, 1

    BLOCK_SIZE = 64
    LOCAl_BLOCKS = 8
    VERT_STRIDE = 16
    HOMO_HEAD = False
    sparse_type = 'home' if HOMO_HEAD else 'hetero'
    dtype = torch.bfloat16
    MAX_N_CTX = 8192

    configs = [triton.testing.Benchmark(
        x_names=['PAST_LEN'],
        x_vals=[2**i - 1 for i in range(8, 14)],
        line_arg='provider',
        line_vals=['torch'] + (['flash'] if HAS_FLASH else []) + ['triton_sparse', 'triton_dense'],
        line_names=['Torch'] + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse', 'Triton-Dense'],
        styles=[('red', '-'), ('blue', '-'), ('green', '-'), ('cyan', '-')],
        ylabel='ms',
        plot_name=f'fused-attention-inference-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}',
        args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'Q_LEN': Q_LEN, 'dtype': torch.float16, 'mode': mode}
    ) for mode in ['fwd']]

    @triton.testing.perf_report(configs)
    def bench_flash_attention_inference(BATCH, H, PAST_LEN, D_HEAD, Q_LEN, mode, provider, dtype=torch.bfloat16, device='cuda'):
        '''Time one provider's forward pass for Q_LEN new tokens after PAST_LEN tokens; returns milliseconds.'''
        assert mode in ['fwd']
        warmup = 25
        rep = 100
        N_CTX = PAST_LEN + Q_LEN
        if provider == 'torch':
            q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            sm_scale = 1.3
            # Use the same head pattern as the triton_sparse provider (HOMO_HEAD, not the
            # always-truthy VERT_STRIDE) so both providers compute the same sparsity.
            mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK_SIZE,
                                                           local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=HOMO_HEAD, return_dense=True)

            fn = lambda: torch_attention(q, k, v, mask_dense, sm_scale=sm_scale, block_size=2048)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms
        if provider == 'triton_sparse':
            q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            sm_scale = 1.3
            sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
                                                                        local_blocks=LOCAl_BLOCKS,
                                                                        vert_stride=VERT_STRIDE,
                                                                        homo_head=HOMO_HEAD,
                                                                        sparse_block_size=BLOCK_SIZE,
                                                                        kernel_block_size=BLOCK_SIZE,
                                                                        device=q.device,
                                                                        inference=True)

            fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms
        if provider == 'triton_dense':
            q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            sm_scale = 1.3
            # local_blocks=1, vert_stride=1: every block column is selected, i.e. a dense causal layout
            sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
                                                                        local_blocks=1,
                                                                        vert_stride=1,
                                                                        homo_head=True,
                                                                        sparse_block_size=BLOCK_SIZE,
                                                                        kernel_block_size=BLOCK_SIZE,
                                                                        device=q.device,
                                                                        inference=True)

            fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms
        if provider == 'flash':
            assert Q_LEN == 1
            lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
            cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
            cu_seqlens[1:] = lengths.cumsum(0)
            cu_seqlens_q = torch.arange(BATCH + 1, device=device, dtype=torch.int32)

            q = torch.randn((BATCH, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            k = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
            v = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)

            fn = lambda: flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens, 1, N_CTX, dropout_p=0, softmax_scale=1.3, causal=False)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms

    test_op(1, 4, 512, 128, dtype=torch.float16, homo_head=False, backward=support_backward)

    bench_flash_attention_inference.run(save_path='.', print_data=True)
    # NOTE: exit() stops the script here. The test_op calls and the training benchmark
    # below are unreachable until this line is removed.
    exit()

    test_op(1, 2, 1024, 64, kernel_block_size=64, sparse_block_size=64,
            dtype=torch.bfloat16, homo_head=False, backward=support_backward)

    test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, sparse_block_size=128,
            kernel_block_size=64, local_blocks=8, vert_stride=8)
    test_op(3, 2, 2047, 128, homo_head=False, backward=False)

    test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, kernel_block_size=64)

    test_op(1, 2, 1024, 128, kernel_block_size=128, sparse_block_size=128, dtype=torch.bfloat16, homo_head=False,
            backward=support_backward, local_blocks=1, vert_stride=1)

    test_op(1, 4, 512 + 256, 128, dtype=torch.float16, homo_head=False, backward=support_backward)

    test_op(2, 4, 8192, 64, homo_head=False, backward=support_backward)
    test_op(2, 4, 8192, 128, dtype=torch.bfloat16, homo_head=False, backward=support_backward)

    test_op(3, 2, 2048, 64, homo_head=True, dtype=torch.bfloat16, backward=False)
    test_op(3, 2, 2048, 64, homo_head=True, backward=support_backward)

    bench_flash_attention.run(save_path='.', print_data=True)
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|