CUDA 007 - GEMM
In this post, I’ll be discussing the GEMM (General Matrix Multiplication) operation in CUDA. My goal is to provide a comprehensive understanding of GEMM and its implementation in CUDA.
Matrix multiplication is a fundamental operation in linear algebra and is widely used in deep-learning models. Inside Large Language Models (LLMs), the attention QKV projections of transformers and the linear layers in MLP blocks spend most of their time on matrix multiplication. Implementing GEMM efficiently can therefore significantly improve the performance of these models during both inference and training.
What is GEMM?
The GEMM operation is defined as the multiplication of two matrices A and B, resulting in a matrix C. Mathematically, it can be represented as:
C = alpha * A * B + beta * C
where alpha and beta are scalar values, A is a matrix of dimensions M x K (rows x columns), B is a matrix of dimensions K x N, and C is the resulting matrix of dimensions M x N.
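To make the definition concrete, here is a minimal CPU reference sketch, assuming row-major storage (the same layout used by the CUDA kernels below):
// Reference GEMM on the CPU: C = alpha * A * B + beta * C.
// A is M x K, B is K x N, C is M x N, all row-major.
void gemm_ref(int M, int N, int K, float alpha, const float* A,
              const float* B, float beta, float* C) {
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float t = 0.0f;
            for (int k = 0; k < K; k++) {
                t += A[i * K + k] * B[k * N + j];
            }
            C[i * N + j] = alpha * t + beta * C[i * N + j];
        }
    }
}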
naïve GEMM
The most straightforward way to implement GEMM in CUDA is to use one thread for each element of the resulting matrix C. Each thread computes a single element C[row][col] by multiplying the row-th row of matrix A with the col-th column of matrix B and then applying the scaling factors alpha and beta. Here’s an example of how you can implement it:
__global__ void gemm_naive(const int M, const int N, const int K, const float alpha,
                           const float* A, const float* B, const float beta, float* C) {
    // one thread per element of C
    const unsigned int row = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        float t = 0.0f;
        // dot product of the row-th row of A and the col-th column of B
        for (int i = 0; i < K; i++) {
            t += A[row * K + i] * B[i * N + col];
        }
        C[row * N + col] = alpha * t + beta * C[row * N + col];
    }
}
In this code, each thread loads a row of matrix A and a column of matrix B and performs the dot product to compute C[row][col], as visualized below:
All threads in a warp work on the same row of matrix C, so they all read the same row of matrix A. This is efficient because the GPU can broadcast the same element to every thread in the warp. Each thread in a warp reads a different column of matrix B, so those memory accesses are coalesced. Although the implementation is simple, it already takes advantage of some GPU features, such as memory coalescing and warp broadcasting. However, it suffers from low arithmetic intensity.
In every loop iteration, each thread reads two 32-bit float elements from global memory and performs a single floating-point multiply-add operation. This results in low arithmetic intensity, which means the GPU’s compute resources are not fully utilized because of the imbalance between memory access and computation.
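To put a number on it: a fused multiply-add counts as 2 floating-point operations, and the two float loads are 8 bytes, so the inner loop performs roughly 2 / 8 = 0.25 FLOPs per byte read from global memory, while modern GPUs need tens of FLOPs per byte of memory bandwidth to keep their compute units busy.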
I ran this code on my RTX 5070 with M, N, and K all equal to 4096. You can see the performance of this kernel in the performance section at the end of this post.
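For reference, the naive kernel might be launched like this (a sketch; cdiv is the ceiling-division helper defined later in this post):
// one thread per element of C
dim3 threads = {16, 16};
dim3 blocks = {cdiv(N, threads.x), cdiv(M, threads.y)};
gemm_naive<<<blocks, threads>>>(M, N, K, alpha, A, B, beta, C);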
2D block tiling GEMM
To improve the arithmetic intensity, we can use a technique called 2D block tiling. This method divides matrices A and B into smaller blocks (tiles) that are processed in parallel by the threads within each thread block.
Here’s a picture to illustrate why 2D block tiling can increase arithmetic intensity:
In naïve GEMM, each thread reads one row of matrix A and one column of matrix B, and calculates one element of matrix C. If we instead read multiple rows and columns into shared memory, we can calculate a whole tile of matrix C. For example, if we read 4 rows and 4 columns into shared memory, we can calculate 16 elements of matrix C. In this way, we only read 4 rows and 4 columns from global memory, but we calculate 16 elements of matrix C, so the arithmetic intensity increases by a factor of 4.
In real situations, the K dimension of matrices A and B can be large, so we cannot load whole rows and columns into shared memory. This is easily solved with the tiling method. Here is how it works:
In the K dimension, we read one tile of rows and columns per loop iteration. In each iteration, we compute the partial dot products over that tile and accumulate the results into the corresponding elements of matrix C.
One way to think about this is that we can split matrices A and B into small blocks and do submatrix multiplication.
Submatrix multiplication: each submatrix has a size of 4x4
This is why we can split the rows and columns into tiles along the K dimension and load them into shared memory iteratively.
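Formally, if we split the K dimension into T tiles, then C = A_0 * B_0 + A_1 * B_1 + ... + A_{T-1} * B_{T-1}, where A_t is the slice of A containing the t-th group of columns and B_t is the slice of B containing the same group of rows. Each term is a small matrix product whose operands fit in shared memory, and the sum is exactly the accumulation performed across loop iterations.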
Now, let’s see how to implement it in CUDA. For simplicity, we assume that M, N, and K are multiples of 16, and we use a 16x16 block size and a 16x16 tile size. This is not fully general, but it helps illustrate the idea.
template<int TILE_DIM = 16>
__global__ void gemm_block_tiling_v1(const int M, const int N, const int K,
                                     const float alpha, const float* A,
                                     const float* B, const float beta, float* C) {
    const unsigned int row = blockIdx.y * blockDim.y + threadIdx.y;
    const unsigned int col = blockIdx.x * blockDim.x + threadIdx.x;
    // shared memory tiles of A and B
    __shared__ float As[TILE_DIM][TILE_DIM];
    __shared__ float Bs[TILE_DIM][TILE_DIM];
    float t = 0.0f;
    for (int i = 0; i < K; i += TILE_DIM) {
        // each thread loads one element of each tile
        As[threadIdx.y][threadIdx.x] = A[row * K + i + threadIdx.x];
        Bs[threadIdx.y][threadIdx.x] = B[(i + threadIdx.y) * N + col];
        __syncthreads(); // wait until the whole tile is loaded
        for (int k = 0; k < TILE_DIM; k++) {
            t += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        }
        __syncthreads(); // wait before the tile is overwritten
    }
    C[row * N + col] = alpha * t + beta * C[row * N + col];
}
We can launch this kernel with the following code, where cdiv calculates the ceiling division (defined below). The block size is the same as the tile size.
constexpr int TILE_DIM = 16;
dim3 threads = {16, 16};
dim3 blocks = {cdiv(N, TILE_DIM), cdiv(M, TILE_DIM)};
ASSERT(TILE_DIM == threads.x && TILE_DIM == threads.y);
gemm_block_tiling_v1<TILE_DIM><<<blocks, threads>>>(M, N, K, alpha, A, B, beta, C);
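For completeness, here is one way cdiv might be defined (the actual helper in the repository may differ):
// ceiling division: the number of tiles needed to cover a elements
constexpr unsigned int cdiv(unsigned int a, unsigned int b) {
    return (a + b - 1) / b;
}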
In the kernel, we use shared memory to store submatrices of A and B, loading a new tile in each iteration. Because the block size and tile size are both 16x16, each thread loads one element of A and one element of B into shared memory. We call __syncthreads() to make sure all threads have finished loading. Then each thread computes the dot product of one row of As and one column of Bs. After that, we call __syncthreads() again to make sure all threads have finished the computation before the next tile of A and B is loaded. Once all iterations are done, each thread has accumulated into t the full dot product of one row of A and one column of B. Finally, we write the result back to C.
By using tiling, we reduce the number of global memory accesses by a factor of TILE_DIM. This cuts global memory traffic and improves performance.
2D block tiling GEMM (multi-element per thread)
In the previous example, I demonstrated how to use 2D block tiling to reduce the number of global memory accesses. For simplicity, I used a square tile of size 16x16, with a 16x16 thread block, so each thread calculates one element of matrix C. For a 4096x4096 matrix, we need to launch 65536 thread blocks, each with 16x16 threads. If each thread only does a small amount of work, the kernel may not be able to hide the latency of global memory access. Even though the warp scheduler hides latency by always issuing warps that are ready to run, the maximum number of warps per SM is limited; if no warps are ready to run, a stall occurs.
In Nsight Compute, the scheduler statistics show that the “No Eligible” percentage is 56.05%, which means that 56.05% of the time no warp was ready to run, so the issue slot was skipped and no instruction was issued.
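If you want to reproduce this, the scheduler statistics can be collected with Nsight Compute’s command-line tool; for example, something like the following (assuming the benchmark binary is named gemm):
ncu --section SchedulerStats ./gemm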
One solution is to handle more elements per thread. For example, with a 64x64 tile size and a 16x16 thread block size, each thread calculates 16 elements of C.
Here is the code:
template<int TILE_DIM, int BLOCK_DIM_Y, int BLOCK_DIM_X>
__global__ void gemm_block_tiling_v2(const int M, const int N, const int K,
                                     const float alpha, const float* A,
                                     const float* B, const float beta, float* C) {
    // +1 padding avoids shared memory bank conflicts
    __shared__ float As[TILE_DIM][TILE_DIM+1];
    __shared__ float Bs[TILE_DIM][TILE_DIM];
    // thread id in thread block
    int tid = blockDim.x * threadIdx.y + threadIdx.x;
    // the upper-left corner of the tile in matrix C
    int cx = blockIdx.x * TILE_DIM;
    int cy = blockIdx.y * TILE_DIM;
    constexpr int BLOCK_SIZE = BLOCK_DIM_Y * BLOCK_DIM_X;
    // how many elements each thread handles in each direction
    constexpr int BATCH_X = TILE_DIM / BLOCK_DIM_X;
    constexpr int BATCH_Y = TILE_DIM / BLOCK_DIM_Y;
    float accum[BATCH_Y][BATCH_X] = {0.0f};
    for (int k = 0; k < K; k += TILE_DIM) {
        /* load global memory into shared memory; here I use a loop with a
         * step size of BLOCK_SIZE to handle more elements per thread. */
        for (int i = 0; i < TILE_DIM * TILE_DIM; i += BLOCK_SIZE) {
            int y = (i + tid) / TILE_DIM; // index in As and Bs
            int x = (i + tid) % TILE_DIM;
            As[y][x] = A[(cy + y) * K + k + x];
            Bs[y][x] = B[(k + y) * N + cx + x];
        }
        __syncthreads();
        for (int i = 0; i < BATCH_Y; i++) {
            for (int j = 0; j < BATCH_X; j++) {
                int y = threadIdx.y + i * BLOCK_DIM_Y;
                int x = threadIdx.x + j * BLOCK_DIM_X;
                for (int ki = 0; ki < TILE_DIM; ki++) {
                    accum[i][j] += As[y][ki] * Bs[ki][x];
                }
            }
        }
        __syncthreads();
    }
    for (int i = 0; i < BATCH_Y; i++) {
        for (int j = 0; j < BATCH_X; j++) {
            int y = threadIdx.y + i * BLOCK_DIM_Y;
            int x = threadIdx.x + j * BLOCK_DIM_X;
            y = cy + y; // index in matrix C
            x = cx + x;
            C[y * N + x] = alpha * accum[i][j] + beta * C[y * N + x];
        }
    }
}
We can launch the kernel with:
constexpr int BLOCK_DIM_X = 16;
constexpr int BLOCK_DIM_Y = 16;
constexpr int TILE_DIM = BLOCK_DIM_X * 4;
dim3 threads = {BLOCK_DIM_X, BLOCK_DIM_Y};
dim3 blocks = {cdiv(N, TILE_DIM), cdiv(M, TILE_DIM)};
gemm_block_tiling_v2<TILE_DIM, BLOCK_DIM_Y, BLOCK_DIM_X><<<blocks, threads>>>(M, N, K, alpha, A, B, beta, C);
The thread block size is 16x16 and the tile size is 64x64, so each thread handles 16 elements of C.
When we load A and B into the shared memory arrays As and Bs, we use a loop with a step size of BLOCK_SIZE. In the first iteration, the 256 threads load the first 256 elements of As and Bs (in flattened order); in the second iteration, they load the next 256 elements, and so on.
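For example, with TILE_DIM = 64 and BLOCK_SIZE = 256, the load loop runs 64 * 64 / 256 = 16 iterations; thread 0 writes As[0][0] in the first iteration and As[4][0] in the second (each iteration fills 256 / 64 = 4 rows), and so on.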
Since the thread block is smaller than the tile, BATCH_Y and BATCH_X give the number of sub-tiles each thread must handle in the row and column directions.
When we store the results back to matrix C, we compute the index in C as cy + y and cx + x, where cy and cx are the upper-left corner of this block’s tile in C, and y and x are the indices within the tile.
2D block tiling GEMM (general)
In the previous example, we used a predefined tile size and block size. It is helpful for understanding the algorithm, but not suitable for all cases. In this section, we make the tile size and block size configurable.
The tile size of matrix A is Bm x Bk and the tile size of matrix B is Bk x Bn. The block size is BLOCK_SIZE. All those sizes can be configured through template parameters.
The kernel implementation is similar to the previous example, but now we must handle the different shapes of As and Bs.
template<int Bm, int Bn, int Bk, int BLOCK_DIM_Y, int BLOCK_DIM_X>
__global__ void gemm_block_tiling_v3(const int M, const int N, const int K,
                                     const float alpha, const float* A,
                                     const float* B, const float beta, float* C) {
    __shared__ float As[Bm][Bk];
    __shared__ float Bs[Bk][Bn];
    // thread id in thread block
    int tid = blockDim.x * threadIdx.y + threadIdx.x;
    // the upper-left corner of tile C in matrix C
    const unsigned int cx = blockIdx.x * Bn;
    const unsigned int cy = blockIdx.y * Bm;
    constexpr int BLOCK_SIZE = BLOCK_DIM_Y * BLOCK_DIM_X;
    constexpr int BATCH_X = Bn / BLOCK_DIM_X;
    constexpr int BATCH_Y = Bm / BLOCK_DIM_Y;
    float accum[BATCH_Y][BATCH_X] = {0.0f};
    for (int k = 0; k < K; k += Bk) {
        // load the Bm x Bk tile of A into As
        for (int i = 0; i < Bm * Bk; i += BLOCK_SIZE) {
            const int idx = i + tid;
            int y = idx / Bk; // index in As
            int x = idx % Bk;
            As[y][x] = A[(cy + y) * K + k + x];
        }
        // load the Bk x Bn tile of B into Bs
        for (int i = 0; i < Bk * Bn; i += BLOCK_SIZE) {
            const int idx = i + tid;
            int y = idx / Bn; // index in Bs
            int x = idx % Bn;
            Bs[y][x] = B[(k + y) * N + cx + x];
        }
        __syncthreads();
        for (int i = 0; i < BATCH_Y; i++) {
            for (int j = 0; j < BATCH_X; j++) {
                int y = threadIdx.y + i * BLOCK_DIM_Y;
                int x = threadIdx.x + j * BLOCK_DIM_X;
                for (int ki = 0; ki < Bk; ki++) {
                    accum[i][j] += As[y][ki] * Bs[ki][x];
                }
            }
        }
        __syncthreads();
    }
    for (int i = 0; i < BATCH_Y; i++) {
        for (int j = 0; j < BATCH_X; j++) {
            int y = threadIdx.y + i * BLOCK_DIM_Y;
            int x = threadIdx.x + j * BLOCK_DIM_X;
            y = cy + y; // index in matrix C
            x = cx + x;
            C[y * N + x] = alpha * accum[i][j] + beta * C[y * N + x];
        }
    }
}
In the above code, we can set the shapes of the tiles of A and B; they no longer need to be square. The tile of A is Bm x Bk, the tile of B is Bk x Bn, and the tile of C is Bm x Bn. Each thread block computes one tile of C.
We can launch the kernel function with the following code:
constexpr int BLOCK_DIM_X = 16;
constexpr int BLOCK_DIM_Y = 16;
constexpr int Bm = BLOCK_DIM_Y * 4;
constexpr int Bn = BLOCK_DIM_X * 4;
constexpr int Bk = 16;
dim3 threads = {BLOCK_DIM_X, BLOCK_DIM_Y};
dim3 blocks = {cdiv(N, Bn), cdiv(M, Bm)};
gemm_block_tiling_v3<Bm, Bn, Bk, BLOCK_DIM_Y, BLOCK_DIM_X><<<blocks, threads>>>(M, N, K, alpha, A, B, beta, C);
Since the thread block is smaller than the tile of C, each thread computes Bm * Bn / BLOCK_SIZE elements of C. The following figure shows the whole process. With Bm = 128, Bn = 128, and BLOCK_SIZE = 256, the thread block iterates over the tile of C (Ct), and each thread computes 64 elements, which are saved in the array accum. Those 64 elements belong to 64 sub-tiles of Ct.
unroll
In order to maximize the instruction-level parallelism (ILP) of the GPU, I use #pragma unroll to unroll the loops. The following code shows how we load tiles of A and B into shared memory:
#pragma unroll
for (int i = 0; i < Bm * Bk; i += BLOCK_SIZE) {
    const int idx = i + tid;
    int y = idx / Bk; // index in As
    int x = idx % Bk;
    As[y][x] = A[(cy + y) * K + k + x];
}
The loop starts at i = 0, and inside each iteration we add tid, the thread index, to i. Why don’t we use tid as the initial value and write the loop as follows?
#pragma unroll
for (int i = tid; i < Bm * Bk; i += BLOCK_SIZE) {
    int y = i / Bk; // index in As
    int x = i % Bk;
    As[y][x] = A[(cy + y) * K + k + x];
}
In every thread, tid is different. If the initial value of i is tid, the compiler cannot unroll the loop, because it cannot determine the number of iterations at compile time. If a loop cannot be unrolled, #pragma unroll is ignored. Conversely, if a loop can be unrolled and the compiler believes unrolling will improve performance, it will unroll the loop even if you didn’t add #pragma unroll.
In CUDA programming, loop unrolling is a very useful optimization technique and can improve performance dramatically when used properly. I tested the above code with and without loop unrolling, and the result shows that loop unrolling improves performance by nearly 10x.
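If you want to verify whether a loop was actually unrolled, one option is to inspect the generated SASS, e.g. with cuobjdump (assuming the compiled binary is named gemm):
cuobjdump --dump-sass gemm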
2D block tiling GEMM (outer product)
After loading tiles of A and B into shared memory, we need to compute the dot products between each row of As and each column of Bs. We can use three nested loops for this purpose:
for (int i = 0; i < BATCH_Y; i++) {
    for (int j = 0; j < BATCH_X; j++) {
        int y = threadIdx.y + i * BLOCK_DIM_Y;
        int x = threadIdx.x + j * BLOCK_DIM_X;
        for (int ki = 0; ki < Bk; ki++) {
            accum[i][j] += As[y][ki] * Bs[ki][x];
        }
    }
}
With BATCH_Y = 8, BATCH_X = 8, and Bk = 8, each thread performs BATCH_Y * BATCH_X * Bk = 8 * 8 * 8 = 512 multiply-add operations to compute its dot products, which requires 2 * 512 = 1024 shared memory accesses.
If we change the order of the loops and move the innermost loop to the outermost position, then in each step of the loop over ki we read the ki-th column of As and the ki-th row of Bs.
for (int ki = 0; ki < Bk; ki++) {
    for (int i = 0; i < BATCH_Y; i++) {
        for (int j = 0; j < BATCH_X; j++) {
            int y = threadIdx.y + i * BLOCK_DIM_Y;
            int x = threadIdx.x + j * BLOCK_DIM_X;
            accum[i][j] += As[y][ki] * Bs[ki][x];
        }
    }
}
In each step over ki, we can first read the needed elements of As and Bs into registers and then perform the multiply-add operations on those registers. This reduces the number of shared memory accesses, and registers are much faster than shared memory.
We need two arrays to hold this thread’s elements of the ki-th column of As and the ki-th row of Bs:
float a[BATCH_Y];
float b[BATCH_X];
When we do the dot product, we read elements from As and Bs into a and b. Then we do the multiply-add operation using registers. The code is like this:
for (int p = 0; p < Bk; p++) {
    // load this thread's slice of the p-th column of As
    for (int i = 0; i < BATCH_Y; i++) {
        int y = threadIdx.y + i * BLOCK_DIM_Y;
        a[i] = As[y][p];
    }
    // load this thread's slice of the p-th row of Bs
    for (int i = 0; i < BATCH_X; i++) {
        int x = threadIdx.x + i * BLOCK_DIM_X;
        b[i] = Bs[p][x];
    }
    // outer product on registers
    for (int i = 0; i < BATCH_Y; i++) {
        for (int j = 0; j < BATCH_X; j++) {
            accum[i][j] += a[i] * b[j];
        }
    }
}
In this way, we can reduce the number of shared memory accesses from 1024 to Bk * (BATCH_Y + BATCH_X) = 8 * (8 + 8) = 128.
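In general, this reordering reduces each thread’s shared memory reads per tile from 2 * Bk * BATCH_Y * BATCH_X to Bk * (BATCH_Y + BATCH_X), while the number of multiply-add operations stays the same.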
In linear algebra, this is called an outer product. In each iteration over ki, we read a column of As and a row of Bs, compute their outer product to form a matrix, and the sum of all these matrices gives the result matrix.
Here is the kernel function:
template<int Bm, int Bn, int Bk, int BLOCK_DIM_Y, int BLOCK_DIM_X>
__global__ void gemm_block_tiling_v4(const int M, const int N, const int K,
                                     const float alpha, const float* A,
                                     const float* B, const float beta, float* C) {
    // shared memory used to store the tiles of A and B
    __shared__ float As[Bm][Bk];
    __shared__ float Bs[Bk][Bn];
    // thread id in thread block
    int tid = blockDim.x * threadIdx.y + threadIdx.x;
    // the upper-left corner of tile C in matrix C
    const unsigned int cx = blockIdx.x * Bn;
    const unsigned int cy = blockIdx.y * Bm;
    constexpr int BLOCK_SIZE = BLOCK_DIM_Y * BLOCK_DIM_X;
    constexpr int BATCH_X = Bn / BLOCK_DIM_X;
    constexpr int BATCH_Y = Bm / BLOCK_DIM_Y;
    float accum[BATCH_Y][BATCH_X] = {0.0f};
    // registers holding one column fragment of As and one row fragment of Bs
    float a[BATCH_Y];
    float b[BATCH_X];
    for (int k = 0; k < K; k += Bk) {
        for (int i = 0; i < Bm * Bk; i += BLOCK_SIZE) {
            const int idx = i + tid;
            int y = idx / Bk; // index in As
            int x = idx % Bk;
            As[y][x] = A[(cy + y) * K + k + x];
        }
        for (int i = 0; i < Bk * Bn; i += BLOCK_SIZE) {
            const int idx = i + tid;
            int y = idx / Bn; // index in Bs
            int x = idx % Bn;
            Bs[y][x] = B[(k + y) * N + cx + x];
        }
        __syncthreads();
        for (int p = 0; p < Bk; p++) {
            for (int i = 0; i < BATCH_Y; i++) {
                int y = threadIdx.y + i * BLOCK_DIM_Y;
                a[i] = As[y][p];
            }
            for (int i = 0; i < BATCH_X; i++) {
                int x = threadIdx.x + i * BLOCK_DIM_X;
                b[i] = Bs[p][x];
            }
            for (int i = 0; i < BATCH_Y; i++) {
                for (int j = 0; j < BATCH_X; j++) {
                    accum[i][j] += a[i] * b[j];
                }
            }
        }
        __syncthreads();
    }
    for (int i = 0; i < BATCH_Y; i++) {
        for (int j = 0; j < BATCH_X; j++) {
            int y = threadIdx.y + i * BLOCK_DIM_Y;
            int x = threadIdx.x + j * BLOCK_DIM_X;
            y = cy + y; // index in matrix C
            x = cx + x;
            C[y * N + x] = alpha * accum[i][j] + beta * C[y * N + x];
        }
    }
}
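As a sketch, v4 is launched the same way as v3; one possible configuration matching the BATCH_Y = BATCH_X = Bk = 8 setting used above (the exact parameters behind the numbers below may differ) is:
constexpr int BLOCK_DIM_X = 16;
constexpr int BLOCK_DIM_Y = 16;
constexpr int Bm = BLOCK_DIM_Y * 8; // BATCH_Y = 8
constexpr int Bn = BLOCK_DIM_X * 8; // BATCH_X = 8
constexpr int Bk = 8;
dim3 threads = {BLOCK_DIM_X, BLOCK_DIM_Y};
dim3 blocks = {cdiv(N, Bn), cdiv(M, Bm)};
gemm_block_tiling_v4<Bm, Bn, Bk, BLOCK_DIM_Y, BLOCK_DIM_X><<<blocks, threads>>>(M, N, K, alpha, A, B, beta, C);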
Performance
Here is the performance of these kernels, compared with cuBLAS:
| kernel | time (ms) | % of cuBLAS |
|---|---|---|
| cuBLAS | 7.19 | 100% |
| gemm_naive | 77.63 | 9.26% |
| gemm_block_tiling_v1 | 52.19 | 13.78% |
| gemm_block_tiling_v2 | 32.10 | 22.40% |
| gemm_block_tiling_v3 | 13.17 | 54.71% |
| gemm_block_tiling_v4 | 11.43 | 62.90% |
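For reference, kernel times like these are typically measured with CUDA events; here is a minimal sketch (setup and error checking omitted; the actual harness in the repository may differ):
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaEventRecord(start);
gemm_block_tiling_v4<Bm, Bn, Bk, BLOCK_DIM_Y, BLOCK_DIM_X><<<blocks, threads>>>(M, N, K, alpha, A, B, beta, C);
cudaEventRecord(stop);
cudaEventSynchronize(stop);
float ms = 0.0f;
cudaEventElapsedTime(&ms, start, stop);
printf("kernel time: %.2f ms\n", ms);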
The performance of the final kernel still has a large gap from cuBLAS on my machine. Other commonly used optimization techniques, like vectorization and data prefetching, are skipped in this post, since they increase the complexity of the code and make it harder to understand. I hope this post helps you understand the basic ideas of matrix multiplication.
Modern GPUs have many hardware resources that can further improve performance: Tensor Cores can perform matrix-matrix multiplication directly, the TMA (Tensor Memory Accelerator) can copy data from global memory to shared memory, and so on. If we want to push performance further on modern GPUs, we need to use these hardware features.
Summary
Now that we understand how to do matrix multiplication in CUDA, let’s review the algorithm from a mathematical perspective.
Here is the naive matrix multiplication algorithm:
for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
        float t = 0.0f;
        for (int k = 0; k < K; k++) {
            t += A[i][k] * B[k][j];
        }
        C[i][j] = t;
    }
}
This is how we do matrix multiplication in the first year of college: we take the inner product of a row from A and a column from B, then store the result in matrix C. In this algorithm, for each element of matrix C, we read a row from matrix A and a column from matrix B. The total number of memory reads is 2 * M * N * K, and the total number of memory writes is M * N.
We can do matrix multiplication in another way:
for (int k = 0; k < K; k++) {
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            C[i][j] += A[i][k] * B[k][j];
        }
    }
}
In this code, the K dimension is the outermost loop. In each iteration, we compute the outer product of a column of matrix A and a row of matrix B. Each outer product produces a partial result of matrix C; after all K iterations, we get the final result. In this algorithm, matrices A and B are each read only once, so the total number of memory reads is M * K + K * N = K * (M + N). The total number of memory writes, however, is K * M * N.
If the matrices are stored in slow memory, we prefer the second algorithm, but we then need a fast memory to hold all M x N elements of matrix C for fast updates. Unfortunately, fast memory is always very limited. Using submatrix multiplication, we can partition matrices A and B into submatrices and store the partial result of each submatrix multiplication in fast memory. This leads to another algorithm:
for (int m = 0; m < M; m += Mtile) {          // iterate over M dimension
    for (int n = 0; n < N; n += Ntile) {      // iterate over N dimension
        for (int k = 0; k < K; k++) {
            for (int i = 0; i < Mtile; i++) { // compute one tile
                for (int j = 0; j < Ntile; j++) {
                    int row = m + i;
                    int col = n + j;
                    C[row][col] += A[row][k] * B[k][col];
                }
            }
        }
    }
}
This is the basic idea behind the block tiling kernels we implemented in this post. For each tile of C, the tiles of A and B are loaded only once, and the tile of C is small enough to fit in fast memory. In this way, we reduce the number of memory reads and writes and increase the arithmetic intensity.
If you understand this algorithm, you can easily understand the kernels implemented in this post. By the way, this is also the basic idea behind the CUTLASS library.
Understanding the matrix multiplication algorithm in CUDA is not easy. You can find many blog posts on the internet, but you will soon get confused by untidy code and fancy optimization methods. When I tried to understand this algorithm, I found that reading other people’s code was not helpful: you easily get lost in a mass of weird variable names and pointer offsets. The best way is to write your own code, starting from the simplest algorithm and then optimizing it.
You can find the source code of the kernel in this post on Github.