近期在部署3d稀疏卷积,需要对Rulebook与weights的计算进行速度优化,先研究下cuda矩阵乘法,特此记录下:
CPU
// Reference CPU implementation of C = A * B for row-major matrices.
// A is M x K, B is K x N, C is M x N; every element of C is overwritten.
void matrix_multiply_cpu(const float *A, const float *B, float *C, int M, int N, int K) {
    for (int row = 0; row < M; ++row) {
        const float *a_row = A + row * K;  // row of A reused for the whole output row
        float *c_row = C + row * N;
        for (int col = 0; col < N; ++col) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k)
                acc += a_row[k] * B[k * N + col];
            c_row[col] = acc;
        }
    }
}
朴素的GPU
采用2D grid 2D block
// Naive GEMM: one thread per element of C, where C = A * B (row-major).
// A is M x K, B is K x N, C is M x N.
// Expected launch: 2D grid / 2D block with threadIdx.x covering rows (M) and
// threadIdx.y covering columns (N), grid sized to at least (M, N) threads.
// Each element of A is read from global memory N times and each element of
// B M times — this is the memory-bound baseline the later kernels improve on.
// NOTE(review): because `row` is derived from threadIdx.x, consecutive threads
// of a warp read a[] with stride K and write c[] with stride N, so global
// accesses are uncoalesced; swapping the row/col <-> x/y mapping (and the grid
// dimensions accordingly) would coalesce them.
__global__ void matrix_multiply_cuda_naive(float *a, float *b, float *c, int M, int N, int K) {
  int row = threadIdx.x + blockIdx.x * blockDim.x;
  int col = threadIdx.y + blockIdx.y * blockDim.y;
  // Guard the grid tail: the grid rarely divides (M, N) exactly.
  if (row >= M || col >= N)
    return;
  float value = 0.0f; // float literal — avoid a double constant in a float kernel
  for (int i = 0; i < K; i++) {
    value += a[row * K + i] * b[i * N + col];
  }
  c[row * N + col] = value;
}
计算一次 FMA(乘累加)之前需要读一次 A 和读一次 B,众所周知,读取 Global Memory 的代价很大,通常都需要几百个 cycle(时钟周期),而计算一次 FMA 通常只需要几个 cycle,大量的时间被花费在了访存上
共享内存
可以将 A 和 B 矩阵先搬运到 Shared Memory(SM 中低延迟的 on-chip memory,block 内线程共享,附 NVIDIA GPU 内存结构图)中降低访存的开销
法1
每个线程计算一个数:
// Shared-memory tiled SGEMM: each thread computes ONE element of C = A * B.
// Tiles of A and B are staged through shared memory to cut global-memory
// traffic; shared memory is small, so the K dimension is walked tile by tile.
// Expected launch: grid(ceil(M/BLOCK_SIZE_K), ceil(N/BLOCK_SIZE_K)),
// block(BLOCK_SIZE_K, BLOCK_SIZE_K); threadIdx.x indexes rows of C and
// threadIdx.y indexes columns. Out-of-range tile entries are zero-filled,
// so M, N, K may take arbitrary values.
// Requires BLOCK_SIZE_K to be defined (the template form is kept commented
// out to match the callers elsewhere in this file).
// template <const int BLOCK_SIZE_K>
__global__ void SgemmV1(float *A, float *B, float *C, int M, int N, int K) {
  int row = threadIdx.x + blockIdx.x * blockDim.x;
  int col = threadIdx.y + blockIdx.y * blockDim.y;
  __shared__ float smem_a[BLOCK_SIZE_K][BLOCK_SIZE_K];
  __shared__ float smem_b[BLOCK_SIZE_K][BLOCK_SIZE_K];
  float sum = 0.0f;
  // Ceil-division tile count: covers a partial tail tile when K is not a
  // multiple of BLOCK_SIZE_K. (The previous "<=" bound ran one extra,
  // all-zero iteration whenever K WAS an exact multiple.)
  for (int t = 0; t < (K + BLOCK_SIZE_K - 1) / BLOCK_SIZE_K; t++) {
    // Stage one tile of A; zero-pad outside the matrix so the inner
    // product below needs no bounds checks.
    int a_col = t * BLOCK_SIZE_K + threadIdx.y;
    smem_a[threadIdx.x][threadIdx.y] =
        (row < M && a_col < K) ? A[row * K + a_col] : 0.0f;
    // Stage one tile of B, zero-padded the same way.
    int b_row = t * BLOCK_SIZE_K + threadIdx.x;
    smem_b[threadIdx.x][threadIdx.y] =
        (col < N && b_row < K) ? B[b_row * N + col] : 0.0f;
    __syncthreads(); // tile fully staged before any thread reads it
#pragma unroll
    for (int k = 0; k < BLOCK_SIZE_K; k++) { // renamed from `i`: no longer shadows the tile loop
      sum += smem_a[threadIdx.x][k] * smem_b[k][threadIdx.y];
    }
    __syncthreads(); // everyone done reading before the next tile overwrites
  }
  if (row < M && col < N) {
    C[row * N + col] = sum;
  }
}
法2
法1只能将访存代价从几百 cycle 降低到几十 cycle,并不改变问题的本质。问题的关键在于主体循环由两条 Load 指令与一条 FMA 指令构成,计算指令只占总体的 1/3,计算访存比过低,最终导致了访存延迟不能被隐藏,从而性能不理想。
一个线程计算 TM×TN 的小方块:
打开思路,若一个 thread 并不只计算一个结果,而是计算多个结果,并且使用 Shared Memory 优化
为了方便起见,先假设 M、N、K 分别能被 BM(128)、BN(128)、BK(8) 整除;代码中同时处理了边缘情况,因此 M、N 实际上可以取任意值。代码如下所示:
// Thread-tiled SGEMM: each thread computes a TM x TN sub-block of C = A * B,
// raising the compute-to-load ratio over the one-element-per-thread version.
// Reference configuration: BM = BN = 128, BK = 8, TM = TN = 8.
// Expected launch: grid(ceil(M/BM), ceil(N/BN)), block(BM/TM, BN/TN) = (16,16),
// i.e. 256 threads; blockIdx.x/threadIdx.x cover the M direction and
// blockIdx.y/threadIdx.y the N direction. M, N, K may take arbitrary values:
// every out-of-range tile entry is written as 0 in shared memory. (Zero-filling
// is required for correctness — the compute loop always runs over the full BK,
// and stale values from the previous iteration would corrupt the K tail.)
// Requires the OFFSET(row, col, ld) row-major indexing macro.
template <const int BM, // 128
          const int BK, // 8
          const int BN, // 128
          const int TM, // 8
          const int TN  // 8
          >
__global__ void SgemmV2(float *__restrict__ a, float *__restrict__ b,
                        float *__restrict__ c, const int M, const int N,
                        const int K) {
  const int bx = blockIdx.x;
  const int by = blockIdx.y;
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  // Flat thread id; the 256 threads cooperate on staging the A and B tiles.
  const int tid = tx * blockDim.y + ty;
  __shared__ float s_a[BM][BK];
  __shared__ float s_b[BK][BN];
  float r_c[TM][TN] = {0.0f}; // per-thread 8x8 accumulator
  // A tile is BM x BK = 128 x 8: two threads per row, 4 floats per thread.
  int load_a_smem_m = tid >> 1;       // row inside the A tile (0..127)
  int load_a_smem_k = (tid & 1) << 2; // starting column: (tid % 2) * 4
  // B tile is BK x BN = 8 x 128: 32 threads per row, 4 floats per thread.
  int load_b_smem_k = tid >> 5;        // row inside the B tile (0..7)
  int load_b_smem_n = (tid & 31) << 2; // starting column: (tid % 32) * 4
  int load_a_gmem_m = bx * BM + load_a_smem_m; // global row of A
  int load_b_gmem_n = by * BN + load_b_smem_n; // global column of B
  // Walk the K dimension one BK-wide tile at a time (ceil-div covers the tail).
  for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {
    // Stage the A tile, zero-padding outside the matrix. The old code filled
    // only the in-range prefix on the K tail, leaving stale shared-memory
    // values that were then multiplied into valid rows of C.
    int load_a_gmem_k = bk * BK + load_a_smem_k;
    int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      s_a[load_a_smem_m][load_a_smem_k + i] =
          (load_a_gmem_m < M && load_a_gmem_k + i < K)
              ? a[load_a_gmem_addr + i]
              : 0.0f;
    }
    // Stage the B tile, zero-padded the same way.
    int load_b_gmem_k = bk * BK + load_b_smem_k;
    int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      s_b[load_b_smem_k][load_b_smem_n + i] =
          (load_b_gmem_k < K && load_b_gmem_n + i < N)
              ? b[load_b_gmem_addr + i]
              : 0.0f;
    }
    __syncthreads(); // both tiles fully staged before anyone computes
#pragma unroll
    for (int k = 0; k < BK; k++) {
#pragma unroll
      for (int m = 0; m < TM; m++) {
        int comp_a_smem_m = tx * TM + m; // hoisted out of the innermost loop
#pragma unroll
        for (int n = 0; n < TN; n++) {
          r_c[m][n] += s_a[comp_a_smem_m][k] * s_b[k][ty * TN + n];
        }
      }
    }
    __syncthreads(); // everyone done reading before the next tiles overwrite
  }
  // Write the TM x TN result back to global memory, guarding the M/N edges.
#pragma unroll
  for (int i = 0; i < TM; i++) {
    int store_c_gmem_m = bx * BM + tx * TM + i; // global row
#pragma unroll
    for (int j = 0; j < TN; j++) {
      int store_c_gmem_n = by * BN + ty * TN + j; // global column
      if (store_c_gmem_m < M && store_c_gmem_n < N) {
        c[OFFSET(store_c_gmem_m, store_c_gmem_n, N)] = r_c[i][j];
      }
    }
  }
}
处理Bank Conflict
一个线程的计算过程如下图所示,每次从Shared memory中取矩阵A的长度为TM的向量和矩阵B的长度为TN的向量,这两个向量做外积并累加到部分和中,一次外积共TM * TN次乘累加,一共需要循环BK次取数和外积。在每一次从Shared Memory load的过程中,存在着显而易见的Bank Conflict:
- 取矩阵A需要取一个列向量,而矩阵A在Shared Memory中是按行存储的;
- 在TM = TN = 8的情况下,无论矩阵A还是矩阵B,从Shared Memory中取数时需要取连续的8个数,即便用LDS.128指令一条指令取四个数,也需要两条指令,由于一个线程的两条load指令的地址是连续的,那么同一个Warp不同线程的同一条load指令的访存地址就是被间隔开的,便存在着Bank Conflict。
为了解决上述的两点Shared Memory的Bank Conflict,采用了以下两点优化:
- 为矩阵A分配Shared Memory时形状分配为[BK][BM],也就是让矩阵A在Shared Memory中按列存储
- 将原本每个线程负责计算的 TM×TN 的矩阵C,划分为下图中这样的四块 4×4 的矩阵C(实验中实测划分成两块,也就是解决A/B中一个矩阵的Bank Conflict就足够,划分成四块并没有比两块带来更高的性能)
注意 N 要为 4 的倍数(以便使用 FLOAT4 读写),M 可以为任意值
// Bank-conflict-avoiding SGEMM: the A tile is stored TRANSPOSED in shared
// memory ([BK][BM]) and each thread's TM x TN output is split into four 4x4
// quadrants whose shared-memory reads are FLOAT4 loads spaced BM/2 (resp.
// BN/2) apart, which spreads a warp's LDS.128 accesses across banks.
// Reference configuration: BM = BN = 128, BK = 8, TM = TN = 8.
// Expected launch: grid(ceil(N/BN), ceil(M/BM)), block(BN/TN, BM/TM) = (16,16);
// blockIdx.y/threadIdx.y cover the M direction, blockIdx.x/threadIdx.x the N
// direction (note: swapped relative to SgemmV2).
// Preconditions: N % 4 == 0 (FLOAT4 stores to c) and K % BK == 0 (the FLOAT4
// staging loads do not guard the K tail); M may take any value.
// Requires the OFFSET(row, col, ld) and FLOAT4(x) macros.
template <const int BM, // 128
          const int BK, // 8
          const int BN, // 128
          const int TM, // 8
          const int TN  // 8
          >
__global__ void SgemmV6(float *__restrict__ a, float *__restrict__ b,
                        float *__restrict__ c, const int M, const int N,
                        const int K) {
  const int bx = blockIdx.x;
  const int by = blockIdx.y;
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int tid = ty * blockDim.x + tx;
  __shared__ float s_a[BK][BM]; // A tile stored transposed
  __shared__ float s_b[BK][BN];
  float r_load_a[4]; // staging registers for the A transpose
  float r_load_b[4]; // staging registers for the B copy
  float r_comp_a[TM];
  float r_comp_b[TN];
  float r_c[TM][TN] = {0.0f};
  // A tile is BM x BK = 128 x 8: two threads per row, 4 floats per thread.
  int load_a_smem_m = tid >> 1;       // row inside the A tile (0..127)
  int load_a_smem_k = (tid & 1) << 2; // starting column: (tid % 2) * 4
  // B tile is BK x BN = 8 x 128: 32 threads per row, 4 floats per thread.
  int load_b_smem_k = tid >> 5;        // row inside the B tile (0..7)
  int load_b_smem_n = (tid & 31) << 2; // starting column: (tid % 32) * 4
  int load_a_gmem_m = by * BM + load_a_smem_m;
  int load_b_gmem_n = bx * BN + load_b_smem_n;
  for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {
    // Read 4 floats of A row-wise into registers, then scatter them
    // column-wise into s_a (the transpose). Rows past M are staged as
    // zeros — the original left the registers stale here, scattering
    // garbage into s_a.
    int load_a_gmem_k = bk * BK + load_a_smem_k;
    if (load_a_gmem_m < M) {
      int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
      FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
    } else {
      r_load_a[0] = r_load_a[1] = r_load_a[2] = r_load_a[3] = 0.0f;
    }
    s_a[load_a_smem_k][load_a_smem_m] = r_load_a[0];
    s_a[load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
    s_a[load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
    s_a[load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
    // Stage B through registers as well (r_load_b was previously declared
    // but never used), zero-padding columns past N.
    int load_b_gmem_k = bk * BK + load_b_smem_k;
    int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
    if (load_b_gmem_n < N) {
      FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
    } else {
      r_load_b[0] = r_load_b[1] = r_load_b[2] = r_load_b[3] = 0.0f;
    }
    FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
    __syncthreads(); // both tiles fully staged before anyone computes
#pragma unroll
    for (int tk = 0; tk < BK; tk++) {
      // Two FLOAT4 halves of A and of B per k step; the BM/2 (BN/2) gap
      // between halves keeps a warp's LDS.128 accesses conflict-free.
      FLOAT4(r_comp_a[0]) = FLOAT4(s_a[tk][ty * TM / 2]);
      FLOAT4(r_comp_a[4]) = FLOAT4(s_a[tk][ty * TM / 2 + BM / 2]);
      FLOAT4(r_comp_b[0]) = FLOAT4(s_b[tk][tx * TN / 2]);
      FLOAT4(r_comp_b[4]) = FLOAT4(s_b[tk][tx * TN / 2 + BN / 2]);
#pragma unroll
      for (int tm = 0; tm < TM; tm++) {
#pragma unroll
        for (int tn = 0; tn < TN; tn++) {
          r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn]; // outer product accumulate
        }
      }
    }
    __syncthreads(); // everyone done reading before the next tiles overwrite
  }
  // Write back the four 4x4 quadrants. Both edges must be guarded: the
  // original checked only N, so blocks on the M edge wrote out of bounds,
  // and its second row block used an off-by-one "n + 4 < N" test that
  // skipped the last valid column group.
#pragma unroll
  for (int i = 0; i < TM / 2; i++) {
    int store_c_gmem_m = by * BM + ty * TM / 2 + i;
    if (store_c_gmem_m >= M)
      continue;
    int store_c_gmem_n = bx * BN + tx * TN / 2;
    int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
    if (store_c_gmem_n < N) { // N % 4 == 0 makes the 4-wide store in-bounds
      FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]);
    }
    if (store_c_gmem_n + BN / 2 < N) {
      FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]);
    }
  }
#pragma unroll
  for (int i = 0; i < TM / 2; i++) {
    int store_c_gmem_m = by * BM + BM / 2 + ty * TM / 2 + i;
    if (store_c_gmem_m >= M)
      continue;
    int store_c_gmem_n = bx * BN + tx * TN / 2;
    int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
    if (store_c_gmem_n < N) {
      FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]);
    }
    if (store_c_gmem_n + BN / 2 < N) {
      FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]);
    }
  }
}
参考:
评论(0)
您还未登录,请登录后发表或查看评论