CUDA 009 - Softmax
What is Softmax?
The softmax operation is widely used in deep learning. It takes a vector of arbitrary real numbers and normalizes it into a probability distribution: a vector of positive numbers that sums to 1. It is often used in the output layer of a neural network to convert raw scores (logits) into probabilities.
Here is the definition of softmax:
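$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$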
Here is an example of softmax on a vector containing 3 elements:
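$$\mathrm{softmax}([1, 2, 3]) = \left[\frac{e^1}{e^1+e^2+e^3},\ \frac{e^2}{e^1+e^2+e^3},\ \frac{e^3}{e^1+e^2+e^3}\right] \approx [0.0900,\ 0.2447,\ 0.6652]$$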
If an input value is relatively large, the exponential function can overflow. To avoid this, we can subtract the maximum value of the vector from each element before applying softmax. This shifts all values to be non-positive, so the exponentials stay between 0 and 1. This is called the safe softmax. The updated formula is:
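$$\mathrm{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_{j=1}^{n} e^{x_j - \max(x)}}$$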
In practice, we always use the safe softmax. In this post, we will implement the safe softmax in CUDA.
The Basic Idea
According to the definition of softmax, it can be calculated in three passes over the input array:
- find the maximum value of the input vector.
- apply the exponential function, sum up all the exponentials.
- divide each element by the sum of all exponentials.
The input of softmax is normally a matrix of shape $(N, D)$, where $N$ is the batch size (the number of rows) and $D$ is the dimension of each row. Softmax is applied independently to each row. We can implement it in C++ as follows:
void softmax(const float *in, float *out, int rows, int dim) {
    for (int row = 0; row < rows; ++row) {
        const float *x = &in[row * dim];
        float *o = &out[row * dim];
        // Initialize with lowest() (the most negative float), not min()
        // (the smallest positive float).
        float max = std::numeric_limits<float>::lowest();
        for (int i = 0; i < dim; i++) {
            max = std::max(max, x[i]);
        }
        float sum = 0;
        for (int i = 0; i < dim; i++) {
            sum += std::exp(x[i] - max);
        }
        for (int i = 0; i < dim; i++) {
            o[i] = std::exp(x[i] - max) / sum;
        }
    }
}
Now, let’s see how we can implement this in CUDA.
Warp Reduction Softmax
Computing softmax involves two reduce operations: one to find the maximum value and another to sum up all exponentials. If we can handle a row in a warp, we can use the warp-level reduction operations provided by CUDA.
Warp-Level Reduction
CUDA provides built-in warp shuffle functions that exchange a variable between threads within a warp. The data is transferred from register to register, which is very fast.
These functions are:
T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize);
T __shfl_up_sync(unsigned mask, T var, unsigned int delta, int width=warpSize);
T __shfl_down_sync(unsigned mask, T var, unsigned int delta, int width=warpSize);
T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width=warpSize);
For the details of these functions, you can refer to the official NVIDIA CUDA documentation. Here, we focus on __shfl_xor_sync, which copies a variable from another lane in the same warp, selected by a bitwise XOR of the caller's own lane ID (threads within a warp are referred to as lanes).
Here’s an example of how you might implement the maximum value reduction within a warp:
__global__ void kernel() {
    float max = max_of_this_thread;  // each thread starts with its own local value
    for (size_t mask = 16; mask > 0; mask /= 2) {
        max = fmax(__shfl_xor_sync(0xffffffff, max, mask), max);
    }
    float max_in_warp = max;  // every lane now holds the warp-wide maximum
}
When you call __shfl_xor_sync, the first argument is the participation mask; 0xffffffff means all 32 threads in the warp take part. For the first iteration, we call __shfl_xor_sync with mask = 16. In the thread with lane 0, since 0 ^ 16 = 16, it copies the value from lane 16. Similarly, in the thread with lane 16, since 16 ^ 16 = 0, it copies the value from lane 0. For the second iteration, the mask is 8; since 0 ^ 8 = 8 and 8 ^ 8 = 0, lane 0 copies from lane 8, and lane 8 copies from lane 0.
It may be hard to understand at first how the values are exchanged. One way to see it is to run the kernel with a small number of threads (e.g., 32) and print the intermediate values after each iteration. I ran this kernel with 32 threads and printed the intermediate values in a warp:
__global__ void kernel() {
    int lane = threadIdx.x % 32;
    for (size_t mask = 16; mask > 0; mask /= 2) {
        int val = __shfl_xor_sync(0xffffffff, lane, mask);
        // print val here
    }
}
I use the lane ID as the value passed to __shfl_xor_sync. Here is the output:
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
--------------------------------------------------------------------------------------------------------
mask=16: 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
mask=8: 8 9 10 11 12 13 14 15 0 1 2 3 4 5 6 7 24 25 26 27 28 29 30 31 16 17 18 19 20 21 22 23
mask=4: 4 5 6 7 0 1 2 3 12 13 14 15 8 9 10 11 20 21 22 23 16 17 18 19 28 29 30 31 24 25 26 27
mask=2: 2 3 0 1 6 7 4 5 10 11 8 9 14 15 12 13 18 19 16 17 22 23 20 21 26 27 24 25 30 31 28 29
mask=1: 1 0 3 2 5 4 7 6 9 8 11 10 13 12 15 14 17 16 19 18 21 20 23 22 25 24 27 26 29 28 31 30
When mask equals 16, the first half of the warp exchanges values with the second half. When mask equals 8, within each half the first quarter exchanges values with the second quarter, and so on.
Next, I use __shfl_xor_sync to find the maximum value in a warp and print the intermediate values.
__global__ void kernel() {
    int lane = threadIdx.x % 32;
    float max = lane;
    for (size_t mask = 16; mask > 0; mask /= 2) {
        max = fmax(__shfl_xor_sync(0xffffffff, max, mask), max);
        // print max here
    }
}
Here is the output:
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
--------------------------------------------------------------------------------------------------------
mask=16: 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
mask=8: 24 25 26 27 28 29 30 31 24 25 26 27 28 29 30 31 24 25 26 27 28 29 30 31 24 25 26 27 28 29 30 31
mask=4: 28 29 30 31 28 29 30 31 28 29 30 31 28 29 30 31 28 29 30 31 28 29 30 31 28 29 30 31 28 29 30 31
mask=2: 30 31 30 31 30 31 30 31 30 31 30 31 30 31 30 31 30 31 30 31 30 31 30 31 30 31 30 31 30 31 30 31
mask=1: 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31 31
We can see that after the final iteration, all threads in the warp hold the same maximum value.
These warp-level shuffle intrinsics are very useful for reductions within a warp. They don't use shared or global memory; the data moves only between registers, which makes them very fast.
We can wrap the __shfl_xor_sync in a function to make it easier to use. Here is the code:
__device__ __forceinline__ float warp_reduce_max(float val) {
#pragma unroll
    for (size_t mask = 16; mask > 0; mask /= 2) {
        val = fmax(__shfl_xor_sync(0xffffffff, val, mask), val);
    }
    return val;
}

__device__ __forceinline__ float warp_reduce_sum(float val) {
#pragma unroll
    for (size_t mask = 16; mask > 0; mask /= 2) {
        val += __shfl_xor_sync(0xffffffff, val, mask);
    }
    return val;
}
The overall pattern is the same; the only difference is how the result of __shfl_xor_sync is combined with the local value. You can easily implement min and other reductions using this pattern.
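For example, a min reduction following the same pattern might look like this (warp_reduce_min is not used in the rest of this post; it is shown only to illustrate the pattern):

// A warp-level min reduction following the same pattern (illustration only).
__device__ __forceinline__ float warp_reduce_min(float val) {
#pragma unroll
    for (size_t mask = 16; mask > 0; mask /= 2) {
        val = fmin(__shfl_xor_sync(0xffffffff, val, mask), val);
    }
    return val;
}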
Warp Reduction Softmax Implementation
Once we can do warp-level reductions, we can implement the softmax kernel like this:
__global__ void softmax_warp_reduce(const float *in, float *out, int rows, int dim) {
    auto threads = gridDim.x * blockDim.x;
    auto warps = threads / 32;
    auto global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
    auto warp_idx = global_thread_idx / 32;
    auto lane = global_thread_idx % 32;
    for (unsigned int row = warp_idx; row < rows; row += warps) {
        const float *x = &in[row * dim];
        float *o = &out[row * dim];
        // Initialize with the most negative float, not the smallest positive one.
        float max = cuda::std::numeric_limits<float>::lowest();
        for (auto col = lane; col < dim; col += 32) {
            max = fmax(x[col], max);
        }
        max = warp_reduce_max(max);
        float sum = 0;
        for (auto col = lane; col < dim; col += 32) {
            sum += exp(x[col] - max);
        }
        sum = warp_reduce_sum(sum);
        for (auto col = lane; col < dim; col += 32) {
            o[col] = exp(x[col] - max) / sum;
        }
    }
}
In this kernel, each warp handles one row of the input matrix. The warp_reduce_max and warp_reduce_sum functions use __shfl_xor_sync() to perform the reductions within a warp.
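For example, a minimal host-side launch might look like this (d_in and d_out are assumed device pointers, and 256 threads per block is just one reasonable choice; the full launch code appears in the Performance section below):

// 256 threads per block = 8 warps per block, so the grid covers roughly one
// warp per row; the grid-stride loop in the kernel handles any remainder.
int threads = 256;
int warps_per_block = threads / 32;
int blocks = (rows + warps_per_block - 1) / warps_per_block;
softmax_warp_reduce<<<blocks, threads>>>(d_in, d_out, rows, dim);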
Block Reduction Softmax
If the dimension of the input vector is large, processing the entire row in a single warp might not be efficient. In such cases, we can handle a row using a thread block.
Block-Level Reduction
We can first compute the maximum within each warp and save the warp-level maxima in shared memory. Then we perform another reduction to get the block-level maximum.
Since a thread block can have at most 1024 threads, that is at most 32 warps. We can allocate 32 elements of shared memory, one per warp. One warp can then read these 32 warp-level maxima and perform another warp-level reduction, writing the block-wide maximum back to shared memory at smem[0]. Finally, all threads read the block-wide maximum from smem[0]. This is the idea of block-level reduction.
We can implement the block-level max reduction like this:
__device__ float block_reduce_max(float *smem, float val) {
    auto warps = blockDim.x * blockDim.y / 32;
    // Reduce within each warp; lane 0 of each warp writes its result to shared memory.
    val = warp_reduce_max(val);
    if (threadIdx.x % 32 == 0) {
        smem[threadIdx.x / 32] = val;
    }
    __syncthreads();
    // The first warp reduces the per-warp maxima.
    if (threadIdx.x < 32) {
        if (threadIdx.x < warps) {
            val = smem[threadIdx.x];
        } else {
            val = cuda::std::numeric_limits<float>::lowest();
        }
        __syncwarp();
        val = warp_reduce_max(val);
        smem[0] = val;
    }
    __syncthreads();
    return smem[0];
}
First, each warp performs a warp-level reduction and saves the result in smem. This is done by the first thread of each warp (threadIdx.x % 32 == 0), at that warp's index (threadIdx.x / 32). Then, the first 32 threads of the block (threadIdx.x < 32) read these warp-level maxima and perform another reduction. (The number of warps may be less than 32, so we need to handle that case.) Finally, all threads read the block-wide maximum from smem[0] and return it.
Similarly, we can implement a block-level sum reduction by replacing warp_reduce_max with warp_reduce_sum and handling the threadIdx.x >= warps case by setting val to 0, as in the sketch below.
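Here is a sketch of block_reduce_sum following the same structure as block_reduce_max:

__device__ float block_reduce_sum(float *smem, float val) {
    auto warps = blockDim.x * blockDim.y / 32;
    // Reduce within each warp, then let the first warp combine the per-warp sums.
    val = warp_reduce_sum(val);
    if (threadIdx.x % 32 == 0) {
        smem[threadIdx.x / 32] = val;
    }
    __syncthreads();
    if (threadIdx.x < 32) {
        // Threads beyond the number of warps contribute the identity element (0).
        val = threadIdx.x < warps ? smem[threadIdx.x] : 0.0f;
        __syncwarp();
        val = warp_reduce_sum(val);
        smem[0] = val;
    }
    __syncthreads();
    return smem[0];
}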
Block-Level Softmax Implementation
With the support of block-level reduction, we can implement the softmax for a row in a thread block like this:
__global__ void softmax_block_reduce(const float *in, float *out, int rows, int dim) {
    __shared__ float smem[32];
    for (unsigned int row = blockIdx.x; row < rows; row += gridDim.x) {
        const float *x = &in[row * dim];
        float *o = &out[row * dim];
        float max = cuda::std::numeric_limits<float>::lowest();
        for (auto col = threadIdx.x; col < dim; col += blockDim.x) {
            max = fmax(x[col], max);
        }
        max = block_reduce_max(smem, max);
        float sum = 0;
        for (auto col = threadIdx.x; col < dim; col += blockDim.x) {
            sum += expf(x[col] - max);
        }
        sum = block_reduce_sum(smem, sum);
        for (auto col = threadIdx.x; col < dim; col += blockDim.x) {
            o[col] = expf(x[col] - max) / sum;
        }
    }
}
Online Softmax
The above implementations make three passes over the input data, which increases global memory traffic and may not be optimal for large inputs. There is a technique called “online softmax” that reduces the number of passes over the input data from three to two.
How Online Softmax Works
These are the three loops of the safe softmax:
float max = std::numeric_limits<float>::lowest();
for (int i = 0; i < dim; i++) {
    max = std::max(max, x[i]);
}
float sum = 0;
for (int i = 0; i < dim; i++) {
    sum += std::exp(x[i] - max);
}
for (int i = 0; i < dim; i++) {
    o[i] = std::exp(x[i] - max) / sum;
}
The second loop depends on the max value computed in the first loop. Online softmax fuses the first and second loops together.
Let’s look at an example of how to get the max value and the sum in one pass:
float sum = 0;
float max = std::numeric_limits<float>::lowest();
for (int i = 0; i < dim; i++) {
    float m = std::max(x[i], max);
    // When a new maximum appears, rescale the running sum to the new base.
    if (i > 0 && m > max) {
        sum = sum * std::exp(max - m);
    }
    max = m;
    sum += std::exp(x[i] - max);
}
The idea is simple: we keep a running maximum and adjust the running sum accordingly. During the iteration, we always subtract the current maximum inside the exponential. When we meet a new maximum m, we rescale the existing sum by multiplying it with std::exp(max - m), where max is the previous maximum.
During the iteration, the value of the sum is $\sum_{j \le i} e^{x_j - \max_{\text{old}}}$. When we encounter a new maximum value $m$, we need to adjust the sum to $\sum_{j \le i} e^{x_j - m}$. We can do this by multiplying the current sum with $e^{\max_{\text{old}} - m}$. Here is the math derivation:

$$\sum_{j \le i} e^{x_j - \max_{\text{old}}} \cdot e^{\max_{\text{old}} - m} = \sum_{j \le i} e^{x_j - \max_{\text{old}} + \max_{\text{old}} - m} = \sum_{j \le i} e^{x_j - m}$$
Using this method, we can get the maximum value and sum in one pass and then compute the softmax in a second pass.
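Putting the two passes together, a CPU reference of the two-pass online softmax for a single row might look like this (a sketch assuming the usual <cmath>, <limits>, and <algorithm> headers; it folds the rescaling into a single unconditional update):

// Two-pass online softmax for one row (illustration only; the CUDA kernel follows below).
void online_softmax_row(const float *x, float *o, int dim) {
    float max = std::numeric_limits<float>::lowest();
    float sum = 0.0f;
    // Pass 1: track the running maximum and rescale the running sum
    // whenever a new maximum appears.
    for (int i = 0; i < dim; i++) {
        float m = std::max(x[i], max);
        sum = sum * std::exp(max - m) + std::exp(x[i] - m);
        max = m;
    }
    // Pass 2: normalize.
    for (int i = 0; i < dim; i++) {
        o[i] = std::exp(x[i] - max) / sum;
    }
}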
Kernel Implementation
Here is the kernel implementation of online softmax with block reduction.
__global__ void online_softmax_block_reduce(const float * __restrict__ in, float *out, int rows, int dim) {
    __shared__ float smem[32];
    for (unsigned int row = blockIdx.x; row < rows; row += gridDim.x) {
        const float *x = &in[row * dim];
        float *o = &out[row * dim];
        float sum = 0;
        float max = cuda::std::numeric_limits<float>::lowest();
        // One pass: thread-local running maximum and rescaled running sum.
        for (auto col = threadIdx.x; col < dim; col += blockDim.x) {
            float m = fmax(x[col], max);
            if (col > threadIdx.x && m > max) {
                sum = sum * __expf(max - m);
            }
            max = m;
            sum += __expf(x[col] - max);
        }
        // Reduce the maximum across the block, then rescale the local sum.
        float m = block_reduce_max(smem, max);
        if (m > max) {
            sum = sum * __expf(max - m);
            max = m;
        }
        sum = block_reduce_sum(smem, sum);
        for (auto col = threadIdx.x; col < dim; col += blockDim.x) {
            o[col] = __expf(x[col] - max) / sum;
        }
    }
}
In the first loop, each thread computes its thread-local maximum and sum. Then we use shared memory to reduce the maximum across all threads in the block, and rescale the sum if necessary. Finally, we compute the softmax value for each element using the block-wide max and sum.
Performance
We can compare the performance of our custom softmax kernel with PyTorch’s.
You should download libtorch and configure your CMake project to use it. The following CMake configuration links the libtorch library and Python3 and sets up the include directories.
set(Torch_DIR ~/Downloads/libtorch-shared-with-deps-2.9.1+cu130/libtorch/share/cmake/Torch)
find_package(Torch REQUIRED)
find_package(Python3 REQUIRED Interpreter Development)
include_directories(${Python3_INCLUDE_DIRS})
include_directories(${TORCH_INCLUDE_DIRS})
link_libraries(${Python3_LIBRARIES})
link_libraries(${TORCH_LIBRARIES})
You can write a launch function in a .cu file and use pybind11 to create a Python module.
// wrap.cu
#include <torch/types.h>
#include <torch/extension.h>
#include "softmax.cuh"

torch::Tensor softmax(torch::Tensor in) {
    const auto rows = in.size(0);
    const auto dim = in.size(1);
    auto out = torch::empty_like(in);
    if (dim <= 1024) {
        dim3 threads = 256;
        dim3 blocks = rows / (threads.x / 32);
        softmax_warp_reduce<<<blocks, threads>>>(in.data_ptr<float>(), out.data_ptr<float>(), rows, dim);
    } else {
        dim3 threads = 256;
        dim3 blocks = rows;
        softmax_block_reduce<<<blocks, threads>>>(in.data_ptr<float>(), out.data_ptr<float>(), rows, dim);
    }
    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("softmax", torch::wrap_pybind_function(softmax), "softmax");
}
In the launch function, we use warp-level reduction if the dimension is at most 1024; otherwise, we use block-level reduction.
Then, you can load this module in Python as follows:
import torch
from torch.utils.cpp_extension import load

# Load the CUDA kernel as a python module
my_softmax = load(name='my_softmax',
                  sources=['softmax.cu', 'wrap.cu'],
                  extra_cuda_cflags=[
                      "-O3",
                      "-U__CUDA_NO_HALF_OPERATORS__",
                      "-U__CUDA_NO_HALF_CONVERSIONS__",
                      "-U__CUDA_NO_HALF2_OPERATORS__",
                      "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                      "--expt-relaxed-constexpr",
                      "--expt-extended-lambda",
                      "--use_fast_math",
                  ],
                  extra_cflags=['-std=c++17'])

logists = torch.rand((4096, 2048)).cuda()
out = my_softmax.softmax(logists)
Then, we can profile the performance using PyTorch’s profiler module:
import torch

logists = torch.rand((4096, 4096)).cuda()

print("Profiling torch.softmax")
with torch.profiler.profile() as prof:
    out1 = logists.softmax(-1)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

print("Profiling my_softmax")
with torch.profiler.profile() as prof:
    out2 = my_softmax.softmax(logists)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

assert torch.allclose(out1, out2)
Here is the output of profiling:
Profiling torch.softmax
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::softmax 0.17% 18.179us 96.61% 10.501ms 10.501ms 0.000us 0.00% 2.247ms 2.247ms 1
aten::_softmax 0.66% 71.319us 96.44% 10.483ms 10.483ms 449.454us 100.00% 2.247ms 2.247ms 1
cudaLaunchKernel 0.27% 29.353us 85.06% 9.246ms 9.246ms 0.000us 0.00% 1.348ms 1.348ms 1
Runtime Triggered Module Loading 84.66% 9.202ms 84.66% 9.202ms 4.601ms 898.908us 200.00% 898.908us 449.454us 2
Activity Buffer Request 9.74% 1.059ms 9.74% 1.059ms 1.059ms 449.454us 100.00% 449.454us 449.454us 1
Lazy Function Loading 0.13% 14.574us 0.13% 14.574us 14.574us 449.454us 100.00% 449.454us 449.454us 1
void at::native::(anonymous namespace)::cunn_SoftMax... 0.00% 0.000us 0.00% 0.000us 0.000us 449.454us 100.00% 449.454us 449.454us 1
cudaStreamIsCapturing 0.03% 3.341us 0.03% 3.341us 3.341us 0.000us 0.00% 0.000us 0.000us 1
cudaMalloc 0.95% 103.488us 0.95% 103.488us 103.488us 0.000us 0.00% 0.000us 0.000us 1
cudaDeviceSynchronize 3.39% 368.524us 3.39% 368.524us 368.524us 0.000us 0.00% 0.000us 0.000us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 10.869ms
Self CUDA time total: 449.454us
Profiling my_softmax
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
online_softmax_block_reduce(float const*, float*, in... 0.00% 0.000us 0.00% 0.000us 0.000us 455.340us 100.00% 455.340us 455.340us 1
aten::empty_like 0.43% 9.822us 75.45% 1.713ms 1.713ms 0.000us 0.00% 0.000us 0.000us 1
aten::empty_strided 1.57% 35.617us 75.02% 1.703ms 1.703ms 0.000us 0.00% 0.000us 0.000us 1
Activity Buffer Request 67.15% 1.525ms 67.15% 1.525ms 1.525ms 0.000us 0.00% 0.000us 0.000us 1
cudaStreamIsCapturing 0.10% 2.327us 0.10% 2.327us 2.327us 0.000us 0.00% 0.000us 0.000us 1
cudaMalloc 6.19% 140.606us 6.19% 140.606us 140.606us 0.000us 0.00% 0.000us 0.000us 1
cudaLaunchKernel 4.73% 107.352us 5.30% 120.415us 120.415us 0.000us 0.00% 0.000us 0.000us 1
Lazy Function Loading 0.58% 13.063us 0.58% 13.063us 13.063us 0.000us 0.00% 0.000us 0.000us 1
cudaDeviceSynchronize 19.25% 437.073us 19.25% 437.073us 437.073us 0.000us 0.00% 0.000us 0.000us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.271ms
Self CUDA time total: 455.340us
Our custom softmax CUDA kernel has roughly the same performance as PyTorch’s softmax, and we have not done any further optimization yet.
Conclusion
In this post, by implementing a custom softmax CUDA kernel, I demonstrated how to perform warp-level and block-level reductions efficiently. I also showed how to expose the kernel to Python using pybind11 and PyTorch's cpp_extension module.
The full code of the kernel is available on my GitHub.