CUDA 010 - Occupancy
In this article, I will explain what occupancy is in CUDA programming, why it exists, how to calculate it, and techniques to tune it for better performance.
What is Occupancy?
Each Streaming Multiprocessor (SM) on a GPU can execute multiple warps concurrently. When a warp stalls, for example while waiting for data from global memory, the SM can switch to another warp that is ready to execute. This way, the SM keeps its compute units busy and hides memory latency. If only a small number of warps are active on the SM, then when one warp is waiting for data, the SM may have nothing to run, leading to underutilization of the GPU.
Occupancy is the ratio of active warps on the SM to the maximum number of warps that could be active on the SM. If occupancy is 100%, it means the SM has the maximum number of active warps, which can help hide memory latency effectively. If occupancy is low, the SM may not have enough warps to switch to when one warp is waiting for data, leading to idle compute units. But low occupancy does not always mean low performance. If a kernel does not have high memory latency or has high instruction-level parallelism (ILP), it may still perform well even with low occupancy.
Factors that Limit Occupancy
GPU hardware imposes limits on the number of blocks and threads per SM. You can query these limits with the cudaGetDeviceProperties function. For example:
#include <cuda_runtime.h>
#include <iostream>
void device_query() {
int device;
cudaDeviceProp prop;
cudaGetDevice(&device);
cudaGetDeviceProperties(&prop, device);
std::cout << prop.name << std::endl;
std::cout << "SMs count: " << prop.multiProcessorCount << std::endl;
std::cout << "max blocks per sm: " << prop.maxBlocksPerMultiProcessor << std::endl;
std::cout << "max threads per sm: " << prop.maxThreadsPerMultiProcessor << std::endl;
std::cout << "max threads per block: " << prop.maxThreadsPerBlock << std::endl;
std::cout << "max warp per sm: " << prop.maxThreadsPerMultiProcessor / prop.warpSize << std::endl;
std::cout << "shared memory per sm: " << prop.sharedMemPerMultiprocessor << std::endl;
std::cout << "shared memory per block: " << prop.sharedMemPerBlock << std::endl;
std::cout << "registers per sm: " << prop.regsPerMultiprocessor << std::endl;
std::cout << "registers per block: " << prop.regsPerBlock << std::endl;
}
The output on my machine is:
NVIDIA GeForce RTX 5070
SMs count: 48
max blocks per sm: 24
max threads per sm: 1536
max threads per block: 1024
max warp per sm: 48
shared memory per sm: 102400
shared memory per block: 49152
registers per sm: 65536
registers per block: 65536
I will use these values to explain how to maximize occupancy in the following sections.
An SM can have at most 48 active warps. Although an SM can only issue instructions from a few warps in any given cycle, it keeps many warps resident so it can schedule around stalls. If an SM has only one active warp, it sits idle whenever that warp is waiting on memory. With multiple active warps, the scheduler can switch to another warp that is ready to execute, so having more active warps helps hide even longer memory latency.
If you want 48 active warps per SM, you need 1536 resident threads per SM (48 warps * 32 threads per warp). Since each block can have at most 1024 threads, you need to launch multiple blocks per SM. And since an SM can hold at most 24 blocks, a block must have at least 64 threads (1536 / 24 = 64).
The registers and shared memory used per block also affect how many blocks can be active on an SM. If a block uses too many registers or too much shared memory, the SM cannot fit as many blocks. In the worst case only one block can be active, which caps the active warps at 32 (1024 threads).
Let’s do some calculations to get a better understanding of how block size, shared memory usage, and register usage affect occupancy.
If the block size is 256 threads, you need (1536 (max threads per SM) / 256) = 6 blocks to fill the SM. For all 6 blocks to fit, each block can use at most about 1/6 of the SM's shared memory, roughly 17 KB (102400 / 6), and in practice less because the default shared-memory carve-out is smaller than 100 KB. Each thread can use at most 42 registers (65536 / (6 * 256)).
If the kernel uses zero shared memory and nearly zero registers, you can launch at most 24 blocks per SM. To achieve maximum occupancy you still need 1536 resident threads per SM, so the block size should be at least 64 threads (1536 / 24).
Calculating Occupancy
Calculating occupancy is straightforward. First, you need the maximum number of active warps per SM. Second, you need the actual number of active warps per SM for your kernel. Once the block size is determined, the only thing left is to find how many blocks can be active on an SM, which cudaOccupancyMaxActiveBlocksPerMultiprocessor reports. Here is a simple example to calculate the occupancy of a kernel:
void calculate_occupancy(const void* kernel, const int block_size) {
int blocks;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks, kernel, block_size, 0);
int device;
cudaDeviceProp prop{};
cudaGetDevice(&device);
cudaGetDeviceProperties(&prop, device);
int max_warps = prop.maxThreadsPerMultiProcessor / prop.warpSize;
int active_warps = blocks * block_size / prop.warpSize;
double occupancy = static_cast<double>(active_warps) / max_warps * 100;
std::cout << "block_size: " << block_size
          << " blocks: " << blocks
          << " occupancy: " << occupancy << "%"
          << std::endl;
}
I ran this function for three different kernels with different block sizes:
- Kernel 1: zero shared memory, nearly zero registers
- Kernel 2: 16 KB of shared memory, nearly zero registers
- Kernel 3: zero shared memory, 128 registers per thread; together with some registers the compiler uses by default, that is about 140+ registers per thread.
| Kernel | Block Size | Blocks Per SM | Occupancy |
|---|---|---|---|
| 1 | 64 | 24 | 100% |
| 1 | 256 | 6 | 100% |
| 1 | 512 | 3 | 100% |
| 1 | 1024 | 1 | 66.66% |
| - | |||
| 2 | 64 | 5 | 20.83% |
| 2 | 256 | 5 | 83.33% |
| 2 | 512 | 3 | 100% |
| 2 | 1024 | 1 | 66.66% |
| - | |||
| 3 | 64 | 6 | 25% |
| 3 | 256 | 1 | 16.67% |
| 3 | 512 | 0 | 0% |
| 3 | 1024 | 0 | 0% |
For kernel 1, the number of blocks per SM is determined by the block size alone. When the block size is 1024, only one block can be active per SM, so the occupancy is 66.66% (1024 / 1536).
For kernel 2, each block uses 16 KB of shared memory, while the maximum amount of shared memory per SM is 100 KB (shared memory and the L1 cache share the same physical storage, and by default the shared-memory carve-out is less than the full 100 KB). In this case, only 5 blocks can be active on an SM, and if the block size is small (64), the occupancy is very low (20.83%). When the block size is 512, the occupancy reaches 100%.
For kernel 3, register usage is the limiting factor. When the block size is 64, only 6 blocks can be active on an SM. When the block size increases to 512, there are not enough registers for even one block, so the kernel cannot be launched with that configuration at all.
Here is a more complete table showing how different register usage and block size affect occupancy:
| Registers | Block Size | Blocks Per SM | Occupancy |
|---|---|---|---|
| 16 | 64 | 24 | 100% |
| 16 | 256 | 6 | 100% |
| 16 | 512 | 3 | 100% |
| 16 | 1024 | 1 | 66.66% |
| - | |||
| 32 | 64 | 24 | 100% |
| 32 | 256 | 6 | 100% |
| 32 | 512 | 3 | 100% |
| 32 | 1024 | 1 | 66.66% |
| - | |||
| 64 | 64 | 14 | 58.33% |
| 64 | 256 | 3 | 50% |
| 64 | 512 | 1 | 33.33% |
| 64 | 1024 | 0 | 0% |
| - | |||
| 128 | 64 | 6 | 25% |
| 128 | 256 | 1 | 16.67% |
| 128 | 512 | 0 | 0% |
| 128 | 1024 | 0 | 0% |
Techniques for Occupancy Tuning
Launching more threads
Launching more threads increases the number of active warps, so every SM has more work to schedule.
Using smaller block size
SMs have a limit on the number of resident threads, for example, 1536 threads per SM on my machine. If the block size is 1024, only one block fits on the SM, so only 1024 threads can be active. The optimal block size varies by kernel; the key is to balance block size against resource usage. Normally, 256 or 512 threads per block are good choices.
Reducing shared memory usage
Shared memory usage can limit the number of blocks per SM. Reducing shared memory usage per block can increase the maximum number of blocks per SM.
Reducing per-thread register usage
Registers are very limited on the GPU. As shown earlier, with a block size of 256 and six resident blocks, each thread can use at most 42 registers if you want to reach 100% occupancy.
Register pressure can hurt performance in another way as well. A thread can use at most 255 registers; if the kernel needs more than that, the compiler spills the excess to local memory, and accessing local memory is much slower than accessing registers.
Summary
Achieving 100% occupancy is not absolutely necessary for optimal performance. If your kernel has high ILP or low memory latency, the warp scheduler can have enough work to do even with lower occupancy. However, low occupancy can limit the ability to hide memory latency, so it is generally a good idea to aim for higher occupancy when possible.
Block size, shared memory usage, and register usage are the main factors that affect occupancy. The number of active warps is determined by the block size and the number of blocks per SM, and the number of blocks per SM is in turn limited by the block size, shared memory usage, and register usage. By choosing these parameters carefully, you can keep enough active warps on each SM to maximize occupancy.