近期在部署3d稀疏卷积,需要对Rulebook与weights的计算进行速度优化,先研究下cuda矩阵乘法,特此记录下:

CPU

// Reference CPU implementation of C = A * B (all matrices row-major).
// A is M x K, B is K x N, C is M x N. Serves as ground truth for the GPU kernels.
void matrix_multiply_cpu(const float *A, const float *B, float *C, int M, int N, int K) {
    for (int row = 0; row < M; ++row) {
        const float *a_row = A + row * K; // row of A reused for every output column
        float *c_row = C + row * N;
        for (int col = 0; col < N; ++col) {
            float acc = 0.0f;
            // Dot product of A's row with B's column (same accumulation order
            // as a straightforward triple loop, so results are bit-identical).
            for (int k = 0; k < K; ++k)
                acc += a_row[k] * B[k * N + col];
            c_row[col] = acc;
        }
    }
}

朴素的GPU

采用2D grid 2D block

// Naive one-element-per-thread matrix multiply: C = A * B (row-major).
// A is M x K, B is K x N, C is M x N.
// For A (M x K) and B (K x N), each element of A is re-read from global
// memory N times and each element of B M times — the kernel is heavily
// memory-bound, which the shared-memory versions below address.
// Launch with a 2D grid and 2D block: x covers rows of C, y covers columns.
__global__ void matrix_multiply_cuda_naive(float *a, float *b, float *c, int M, int N, int K) {
    int row = threadIdx.x + blockIdx.x * blockDim.x;
    int col = threadIdx.y + blockIdx.y * blockDim.y;

    // Guard the ragged edge: grids rarely divide M/N exactly.
    if (row >= M || col >= N)
        return;

    float value = 0.0f; // fix: 0.0 was a double literal in a float kernel
    for (int i = 0; i < K; i++) {
        value += a[row * K + i] * b[i * N + col];
    }
    c[row * N + col] = value;
}

计算一次 FMA(乘累加)之前需要读一次 A 和读一次 B,众所周知,读取 Global Memory 的代价很大,通常都需要几百个 cycle(时钟周期),而计算一次 FMA 通常只需要几个 cycle,大量的时间被花费在了访存上

共享内存

可以将 A 和 B 矩阵先搬运到 Shared Memory(SM 中低延迟的 on-chip memory,block 内线程共享,附 NVIDIA GPU 内存结构图)中降低访存的开销

法1

每个线程计算一个数:

// Tiled matrix multiply using shared memory (one output element per thread).
// Tiles of A and B are staged through shared memory to cut global-memory
// traffic; shared memory is limited, hence the blocking over K.
// grid : (ceil(M/BLOCK_SIZE_K), ceil(N/BLOCK_SIZE_K))
// block: (BLOCK_SIZE_K, BLOCK_SIZE_K)
// NOTE(review): BLOCK_SIZE_K is assumed to be a compile-time macro/constant
// defined elsewhere — the template version is commented out below.
// template <const int BLOCK_SIZE_K>
__global__ void SgemmV1(float *A, float *B, float *C, int M, int N, int K) {
    int row = threadIdx.x + blockIdx.x * blockDim.x; // row of C owned by this thread
    int col = threadIdx.y + blockIdx.y * blockDim.y; // column of C owned by this thread

    __shared__ float smem_a[BLOCK_SIZE_K][BLOCK_SIZE_K];
    __shared__ float smem_b[BLOCK_SIZE_K][BLOCK_SIZE_K];

    // Each block computes one BLOCK_SIZE_K x BLOCK_SIZE_K tile of C, marching
    // over K one tile at a time. Ceil-division covers the ragged tail exactly;
    // the old `<=` bound ran one extra all-zero iteration when K divided evenly.
    float sum = 0.0f;
    int numTiles = (K + BLOCK_SIZE_K - 1) / BLOCK_SIZE_K;
    for (int t = 0; t < numTiles; t++) {
        // Stage a tile of A; out-of-range entries are zeroed so they
        // contribute nothing to the dot product.
        int ida = row * K + t * BLOCK_SIZE_K + threadIdx.y;
        if (row < M && BLOCK_SIZE_K * t + threadIdx.y < K) {
            smem_a[threadIdx.x][threadIdx.y] = A[ida];
        } else {
            smem_a[threadIdx.x][threadIdx.y] = 0.0f;
        }

        // Stage the matching tile of B, zero-padded the same way.
        int idb = (threadIdx.x + t * BLOCK_SIZE_K) * N + col;
        if (col < N && BLOCK_SIZE_K * t + threadIdx.x < K) {
            smem_b[threadIdx.x][threadIdx.y] = B[idb];
        } else {
            smem_b[threadIdx.x][threadIdx.y] = 0.0f;
        }

        __syncthreads(); // tiles fully written before anyone reads them
#pragma unroll
        for (int k = 0; k < BLOCK_SIZE_K; k++) { // renamed from `i`: no longer shadows the tile index
            sum += smem_a[threadIdx.x][k] * smem_b[k][threadIdx.y];
        }
        __syncthreads(); // everyone done reading before the next tile overwrites
    }

    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}

法2

法1只能将访存代价从几百 cycle 降低到几十 cycle,并不改变问题的本质。问题的关键在于主体循环由两条 Load 指令与一条 FMA 指令构成,计算指令只占总体的 1/3,计算访存比过低,最终导致了访存延迟不能被隐藏,从而性能不理想。
一个线程计算TM*TN的小方块:

打开思路,若一个 thread 并不只计算一个结果,而是计算 多 个结果,并且使用 Shared Memory 优化

为了方便起见假设M、N、K可以整除BM(128)、BN(128)、BK(8),同时考虑边缘情况,M,N可以取任意值,代码如下所示:

// Tiled SGEMM where each thread computes a TM x TN sub-tile of C.
// A is M x K, B is K x N, C is M x N, all row-major.
// grid : (ceil(M/BM), ceil(N/BN))   block : (BM/TM, BN/TN) — 16 x 16 = 256
// threads with the commented parameter values. M, N, K may be arbitrary.
template <const int BM, // block tile rows of C, e.g. 128
          const int BK, // K-tile depth, e.g. 8
          const int BN, // block tile cols of C, e.g. 128
          const int TM, // rows of C per thread, e.g. 8
          const int TN  // cols of C per thread, e.g. 8
          >
__global__ void SgemmV2(float *__restrict__ a, float *__restrict__ b,
                        float *__restrict__ c, const int M, const int N,
                        const int K) {
    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    // Flat thread id; x is the "row" direction of this launch.
    const int tid = tx * blockDim.y + ty;

    __shared__ float s_a[BM][BK];
    __shared__ float s_b[BK][BN];

    float r_c[TM][TN] = {0.0f}; // per-thread TM x TN accumulator

    // 256 threads stage a BM x BK (128 x 8) tile of A: 2 threads per row,
    // 4 consecutive floats each.
    int load_a_smem_m = tid >> 1;       // row inside the A tile (tid / 2)
    int load_a_smem_k = (tid & 1) << 2; // col inside the A tile (tid % 2 * 4)
    // 256 threads stage a BK x BN (8 x 128) tile of B: 32 threads per row.
    int load_b_smem_k = tid >> 5;        // row inside the B tile (tid / 32)
    int load_b_smem_n = (tid & 31) << 2; // col inside the B tile (tid % 32 * 4)

    int load_a_gmem_m = bx * BM + load_a_smem_m; // global row of A (grid.x walks M)
    int load_b_gmem_n = by * BN + load_b_smem_n; // global col of B (grid.y walks N)

    // March over K one BK-deep tile at a time, staging through shared memory.
    for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {
        // ---- stage the A tile ----
        int load_a_gmem_k = bk * BK + load_a_smem_k;
        if (load_a_gmem_m < M) {
            int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
            if (load_a_gmem_k + 3 < K) {
                s_a[load_a_smem_m][load_a_smem_k + 0] = a[load_a_gmem_addr + 0];
                s_a[load_a_smem_m][load_a_smem_k + 1] = a[load_a_gmem_addr + 1];
                s_a[load_a_smem_m][load_a_smem_k + 2] = a[load_a_gmem_addr + 2];
                s_a[load_a_smem_m][load_a_smem_k + 3] = a[load_a_gmem_addr + 3];
            } else {
                // Ragged K tail: copy the valid remainder and ZERO the rest.
                // Fix: the old code left the tail entries holding stale values
                // from the previous bk iteration, corrupting the result
                // whenever K % BK != 0 (the inner loop always runs over BK).
                int rem = K - load_a_gmem_k;
                for (int i = 0; i < 4; i++)
                    s_a[load_a_smem_m][load_a_smem_k + i] =
                        (i < rem) ? a[load_a_gmem_addr + i] : 0.0f;
            }
        } else {
            // Rows past M contribute zeros (results are masked at store time;
            // this also avoids reading uninitialized shared memory).
            for (int i = 0; i < 4; i++)
                s_a[load_a_smem_m][load_a_smem_k + i] = 0.0f;
        }

        // ---- stage the B tile (same zero-padding fix for the N tail) ----
        int load_b_gmem_k = bk * BK + load_b_smem_k;
        if (load_b_gmem_k < K) {
            int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
            if (load_b_gmem_n + 3 < N) {
                s_b[load_b_smem_k][load_b_smem_n + 0] = b[load_b_gmem_addr + 0];
                s_b[load_b_smem_k][load_b_smem_n + 1] = b[load_b_gmem_addr + 1];
                s_b[load_b_smem_k][load_b_smem_n + 2] = b[load_b_gmem_addr + 2];
                s_b[load_b_smem_k][load_b_smem_n + 3] = b[load_b_gmem_addr + 3];
            } else {
                int rem = N - load_b_gmem_n;
                for (int i = 0; i < 4; i++)
                    s_b[load_b_smem_k][load_b_smem_n + i] =
                        (i < rem) ? b[load_b_gmem_addr + i] : 0.0f;
            }
        } else {
            for (int i = 0; i < 4; i++)
                s_b[load_b_smem_k][load_b_smem_n + i] = 0.0f;
        }
        __syncthreads(); // tiles complete before compute

#pragma unroll
        for (int k = 0; k < BK; k++) {
#pragma unroll
            for (int m = 0; m < TM; m++) {
                int comp_a_smem_m = tx * TM + m; // hoisted: invariant in n
#pragma unroll
                for (int n = 0; n < TN; n++) {
                    r_c[m][n] += s_a[comp_a_smem_m][k] * s_b[k][ty * TN + n];
                }
            }
        }

        __syncthreads(); // all reads done before the next tile overwrites
    }

    // Write back the TM x TN result, masking the ragged M/N edges.
#pragma unroll
    for (int i = 0; i < TM; i++) {
        int store_c_gmem_m = bx * BM + tx * TM + i; // global row of C
#pragma unroll
        for (int j = 0; j < TN; j += 1) {
            int store_c_gmem_n = by * BN + ty * TN + j; // global col of C
            if (store_c_gmem_m < M && store_c_gmem_n < N) {
                int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
                c[store_c_gmem_addr] = r_c[i][j];
            }
        }
    }
}

处理Bank Conflict

一个线程的计算过程如下图所示,每次从Shared memory中取矩阵A的长度为TM的向量和矩阵B的长度为TN的向量,这两个向量做外积并累加到部分和中,一次外积共TM * TN次乘累加,一共需要循环BK次取数和外积。在每一次从Shared Memory load的过程中,存在着显而易见的Bank Conflict:

  • 取矩阵A需要取一个列向量,而矩阵A在Shared Memory中是按行存储的;
  • 在TM = TN = 8的情况下,无论矩阵A还是矩阵B,从Shared Memory中取数时需要取连续的8个数,即便用LDS.128指令一条指令取四个数,也需要两条指令,由于一个线程的两条load指令的地址是连续的,那么同一个Warp不同线程的同一条load指令的访存地址就是被间隔开的,便存在着Bank Conflict。

为了解决上述的两点Shared Memory的Bank Conflict,采用了以下两点优化:

  • 为矩阵A分配Shared Memory时形状分配为[BK][BM],也就是让矩阵A在Shared Memory中按列存储
  • 将原本每个线程负责计算的TM × TN的矩阵C,划分为下图中这样的四块4 × 4的矩阵C(实验中实测划分成两块,也就是解决A/B中一个矩阵的Bank Confilict就足够,划分成四块并没有比两块带来更高的性能)

注意N要为4的倍数,M可以为任意数

// Bank-conflict-reduced SGEMM: the A tile is stored TRANSPOSED in shared
// memory, and each thread's TM x TN output is split into four 4x4 quadrants
// (offset by BM/2 / BN/2) so the per-warp LDS.128 reads spread across banks.
// grid : (ceil(N/BN), ceil(M/BM))   block : (BN/TN, BM/TM) = 16 x 16
// Preconditions: N % 4 == 0 and K % BK == 0 — the FLOAT4 global accesses are
// unguarded in those directions. M may be arbitrary.
template <const int BM, // block tile rows of C, e.g. 128
          const int BK, // K-tile depth, e.g. 8
          const int BN, // block tile cols of C, e.g. 128
          const int TM, // rows of C per thread, e.g. 8
          const int TN  // cols of C per thread, e.g. 8
          >
__global__ void SgemmV6(float *__restrict__ a, float *__restrict__ b,
                        float *__restrict__ c, const int M, const int N,
                        const int K) {
    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int tid = ty * blockDim.x + tx;

    // s_a is [BK][BM] (transposed) so the compute loop reads a contiguous
    // run of A values instead of a strided column.
    __shared__ float s_a[BK][BM];
    __shared__ float s_b[BK][BN];

    float r_load_a[4];  // staging registers for the A transpose
    float r_comp_a[TM]; // A fragment for one k-slice
    float r_comp_b[TN]; // B fragment for one k-slice
    float r_c[TM][TN] = {0.0f};
    // (removed unused r_load_b from the original)

    int load_a_smem_m = tid >> 1;        // A-tile row this thread loads (tid / 2)
    int load_a_smem_k = (tid & 1) << 2;  // A-tile col (tid % 2 * 4)
    int load_b_smem_k = tid >> 5;        // B-tile row (tid / 32)
    int load_b_smem_n = (tid & 31) << 2; // B-tile col (tid % 32 * 4)

    int load_a_gmem_m = by * BM + load_a_smem_m; // grid.y walks M in this kernel
    int load_b_gmem_n = bx * BN + load_b_smem_n; // grid.x walks N

    for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {
        // Load 4 A values into registers (read row-wise), then scatter them
        // into the transposed shared tile (store column-wise).
        if (load_a_gmem_m < M) {
            int load_a_gmem_k = bk * BK + load_a_smem_k;
            int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
            FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
        } else {
            // Fix: the old code wrote r_load_a into shared memory
            // uninitialized/stale for rows past M.
            r_load_a[0] = r_load_a[1] = r_load_a[2] = r_load_a[3] = 0.0f;
        }
        s_a[load_a_smem_k][load_a_smem_m] = r_load_a[0];
        s_a[load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
        s_a[load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
        s_a[load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];

        // B goes straight from global to shared (no transpose needed).
        int load_b_gmem_k = bk * BK + load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        if (load_b_gmem_n < N) {
            FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(b[load_b_gmem_addr]);
        } else {
            // Fix: zero-fill columns past N instead of leaving them stale.
            s_b[load_b_smem_k][load_b_smem_n + 0] = 0.0f;
            s_b[load_b_smem_k][load_b_smem_n + 1] = 0.0f;
            s_b[load_b_smem_k][load_b_smem_n + 2] = 0.0f;
            s_b[load_b_smem_k][load_b_smem_n + 3] = 0.0f;
        }

        __syncthreads(); // tiles complete before compute

#pragma unroll
        for (int tk = 0; tk < BK; tk++) {
            // Gather two half-fragments of A and two of B per thread; the
            // BM/2 / BN/2 offsets split the output into four 4x4 quadrants,
            // which is what breaks up the shared-memory bank conflicts.
            FLOAT4(r_comp_a[0]) = FLOAT4(s_a[tk][ty * TM / 2]);
            FLOAT4(r_comp_a[4]) = FLOAT4(s_a[tk][ty * TM / 2 + BM / 2]);
            FLOAT4(r_comp_b[0]) = FLOAT4(s_b[tk][tx * TN / 2]);
            FLOAT4(r_comp_b[4]) = FLOAT4(s_b[tk][tx * TN / 2 + BN / 2]);

            // Outer product of the fragments, accumulated into r_c.
#pragma unroll
            for (int tm = 0; tm < TM; tm++) {
#pragma unroll
                for (int tn = 0; tn < TN; tn++) {
                    r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
                }
            }
        }

        __syncthreads(); // all reads done before the next tile overwrites
    }

    // Write back the upper pair of 4x4 quadrants (rows in the first BM/2).
#pragma unroll
    for (int i = 0; i < TM / 2; i++) {
        int store_c_gmem_m = by * BM + ty * TM / 2 + i;
        if (store_c_gmem_m >= M) continue; // fix: old code wrote out of bounds when M % BM != 0
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        if (store_c_gmem_n < N) {
            FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]);
        }
        if (store_c_gmem_n + BN / 2 < N) {
            FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]);
        }
    }
    // Write back the lower pair of quadrants (rows in the second BM/2).
#pragma unroll
    for (int i = 0; i < TM / 2; i++) {
        int store_c_gmem_m = by * BM + BM / 2 + ty * TM / 2 + i;
        if (store_c_gmem_m >= M) continue; // same M guard as above
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        // Fix: was `store_c_gmem_n + 4 < N`, which skipped the last valid
        // column group; with N % 4 == 0 the `< N` check suffices.
        if (store_c_gmem_n < N) {
            FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]);
        }
        if (store_c_gmem_n + BN / 2 < N) {
            FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]);
        }
    }
}

参考: