WangYu::Space

Study, think, create, and grow. Teach yourself and teach others.

CUDA 014 - MMA 指令的用法详解

分类:CUDA标签: CUDA创建时间:2025-12-26 20:40:17

MMA(Matrix Multiply and Accumulate)是 NVIDIA GPU 提供的专用矩阵乘加指令,用于高效执行 D = A * B + C 运算。随着深度学习和高性能计算对矩阵运算需求的增加,GPU 针对矩阵乘加操作进行了硬件级优化,MMA 指令正是其中的关键组成部分。而且随着硬件的不断升级,MMA 指令也在不断演进,支持更多的数据类型和矩阵规格。本文将介绍一个由 Ampere 架构引入的指令,后续版本引入的新指令,在原理上存在类似之处,因此本文内容对于最新的指令的用法的学习也具有一定的参考价值。

本文主要介绍 mma.sync.aligned.m16n8k16 指令,它可以完成一个 M=16、N=8、K=16 的矩阵乘加运算,支持 fp16 和 bf16 两种输入数据类型,累加结果使用 fp32 存储,是目前应用最广泛的 MMA 指令之一。在实际的矩阵运算中,通常采用分块(tiling)技术将大矩阵分解为小矩阵块进行计算。MMA 指令正是针对这些小矩阵块的乘加运算进行了硬件级优化,可以大幅提升矩阵乘加运算的性能。本文将详细介绍如何使用 MMA 指令来加速矩阵乘加运算,最后会给出详尽的代码示例。

MMA 指令的格式

MMA 指令的格式有很多种,完整的格式说明可以参考 PTX 文档,这里只介绍最常用的几种格式:

mma.sync.aligned.m16n8k8.row.col.dtype.f16.f16.ctype  d, a, b, c;
mma.sync.aligned.m16n8k16.row.col.dtype.f16.f16.ctype d, a, b, c;
mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32   d, a, b, c;

.ctype   = {.f16, .f32};
.dtype   = {.f16, .f32};

mma 指令支持多种矩阵大小和数据类型的组合。这是一个 warp 级别的指令,执行时需要 warp 内所有 32 个线程协同参与,矩阵数据分散存储在各线程的寄存器中。

让我们逐一解析指令中各个修饰符的含义:

数据分布:对于 m16n8k16 指令,矩阵元素在 32 个线程间的分布如下:

下面是对 mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 指令封装后的函数:

__device__ void
fma(float         & d0, float         & d1, float         & d2, float         & d3,
    uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
    uint32_t const& b0, uint32_t const& b1,
    float const   & c0, float const   & c1, float const   & c2, float const   & c3) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,  %1,  %2,  %3},"
        "{%4,  %5,  %6,  %7},"
        "{%8,  %9},"
        "{%10, %11, %12, %13};\n"
        : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
        : "r"(a0),  "r"(a1),  "r"(a2),  "r"(a3),
          "r"(b0),  "r"(b1),
          "f"(c0),  "f"(c1),  "f"(c2),  "f"(c3));
}

寄存器使用说明:因为数据类型为 f32.f16.f16.f32,A/B 矩阵的元素为 16 位(2 字节),C/D 矩阵的元素为 32 位(4 字节)。由于寄存器为 32 位宽,因此:

后文将详细介绍如何将共享内存中的矩阵数据加载到寄存器中,并重点讲解 mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 指令的使用方法。

MMA 指令的数据布局

MMA 指令由 warp 内所有 32 个线程协同执行,因此需要理解矩阵数据如何分布在这些线程的寄存器中。这个分布模式是硬件预定义的,我们必须按照这个模式来组织数据,才能正确调用 MMA 指令。

下图展示了 m16n8k16 指令中,A、B、C 三个矩阵在 warp 内各线程的分布情况:

上图中左侧、上侧和右侧分别表示 A、B 和 C 矩阵,其中 A 是 16x16 矩阵,B 和 C 是 16x8 矩阵。每一个小方格表示矩阵中的一个数据元素,方格中的 T0、T1、T2 等表示该数据所在的线程编号,V0、V1、V2 表示此数据是该线程的第几个数据(指令参数传入的顺序)。

初次看到这个图可能会感到困惑,但其中有规律可循。关键点是矩阵被划分为多个 8×8 的子矩阵块。下图展示了 A 矩阵(16×16)的数据布局:

数据分布规律

  1. 16×16 的 A 矩阵被划分为 4 个 8×8 的子矩阵
  2. 每个 8×8 子矩阵按行优先方式分布在 32 个线程中
  3. 每个线程负责一个子矩阵中的 2 个元素(一行中的相邻元素)
  4. 对于单个 8×8 子矩阵:线程 0-3 存储第 0 行,线程 4-7 存储第 1 行,以此类推
  5. 因为有 4 个子矩阵,每个线程总共存储 8 个元素(4 个子矩阵 × 每个 2 个元素)

重要:只要按照这个分布模式将数据加载到寄存器中,MMA 指令就能正确识别和处理数据。

我们可以按照上述分布规律,将共享内存中的矩阵数据加载到 warp 内各线程的寄存器中。例如,加载 A 矩阵的代码如下:

__shared__ half smem_a[16][16];  // A 矩阵存储在共享内存中

// 每个线程用 4 个寄存器存储 8 个 fp16 元素(每个寄存器打包 2 个元素)
uint32_t a0, a1, a2, a3;

int lane = threadIdx.x % 32;  // 线程在 warp 中的编号 (0-31)
int row = lane / 4;           // 计算该线程负责的行号 (0-7)
int col = lane % 4 * 2;       // 计算该线程负责的列号起始位置 (0, 2, 4, 6)

// 加载 4 个 8×8 子矩阵中的数据
a0 = *reinterpret_cast<uint32_t*>(&smem_a[row][col]);        // 左上子矩阵
a1 = *reinterpret_cast<uint32_t*>(&smem_a[row + 8][col]);    // 左下子矩阵
a2 = *reinterpret_cast<uint32_t*>(&smem_a[row][col + 8]);    // 右上子矩阵
a3 = *reinterpret_cast<uint32_t*>(&smem_a[row + 8][col + 8]);// 右下子矩阵

完整的代码示例可以参考 CUDA MMA 示例代码

B 矩阵是 16x8 的矩阵,可以划分为 2 个 8x8 的子矩阵,其数据布局如下图所示:

这里 B 矩阵的 8x8 子矩阵是按列优先的方式存储在 0-31 号线程的寄存器中的,每个线程存储两个数据。两个 8x8 子矩阵的数据分别存储在每个线程的两个寄存器中。

C 矩阵数据的分布也是类似的,它由两个 8x8 的子矩阵组成,每个子矩阵按行优先的方式存储在 0-31 号线程的寄存器中,每个线程存储两个数据,因为 C 矩阵的数据类型为 f32,每个数据占用一个寄存器。两个 8x8 子矩阵的 4 个数据分别存储在每个线程的 4 个寄存器中。

虽然可以手动计算每个线程的数据地址并加载,但这种方式代码复杂且容易出错。幸运的是,PTX 提供了专门的 ldmatrix 指令,能够高效地将共享内存中的矩阵数据批量加载到寄存器中,并自动完成数据分布。下一节将详细介绍如何使用这个强大的指令。

加载矩阵到寄存器

在调用 mma 指令之前,需要先将 A、B、C 矩阵的数据从共享内存加载到寄存器中。相比手动计算地址和分布逻辑,ldmatrix 指令提供了一种简洁且高效的解决方案

ldmatrix 指令能够一次性从共享内存中读取 8×8 的子矩阵,并自动将数据分散到 warp 内 32 个线程的寄存器中。其工作机制如下

  1. Warp 内前 8 个线程(lane 0-7)各提供一个内存地址,指向矩阵的 8 行
  2. ldmatrix 从每个地址读取 8 个元素(16 字节)
  3. 读取的 8×8=64 个元素按预定规则分布到 32 个线程的寄存器中
  4. 每个线程获得 2 个元素,打包存储在一个 32 位寄存器中

下图展示了 8×8 矩阵加载后在各线程中的分布:

图中数字表示这个位置的元素所在的线程编号,每个线程存储两个数据

目前你只需要知道 ldmatrix 指令可以方便地将共享内存中的 8x8 矩阵加载到寄存器中,关于加载的具体细节,后文会做详细介绍。

ldmatrix 指令格式

ldmatrix 指令完整的格式说明可以参考 PTX 文档,这里只介绍最常用的格式:

ldmatrix.sync.aligned.shape.num{.trans}{.ss}.type r, [p];

.shape   = {.m8n8};
.num     = {.x1, .x2, .x4};
.ss      = {.shared{::cta}};
.type    = {.b16, .b8};

这里的 shape 表示加载的矩阵大小,.m8n8 表示加载 8x8 的矩阵。对于 16 位数据通常使用 .m8n8,如需加载 16x16,可用 .x4 组合成 4 个 8x8 分块完成加载。

.num 表示加载的矩阵个数,x2 表示加载两个矩阵,x4 表示加载四个矩阵。

.trans 表示加载的矩阵是否需要转置后再存储到寄存器中。对于 A 矩阵来说,因为它是按行优先的方式存储在共享内存中的,因此不需要转置,省略此修饰符即可。对于 B 矩阵来说,因为它是按列优先的方式存储在共享内存中的,因此需要转置后再存储到寄存器中,此时需要加上此修饰符。

.ss 表示加载的数据所在的内存空间,这里通常使用共享内存,因此使用 .shared 修饰符。

.type 表示数据类型,.b16 表示 16 bit 的数据,.b8 表示 8 bit 的数据。这里并不区分数据的具体类型是 fp16 还是 bf16,只要数据占用 16 bit 就可以使用 .b16

r 表示存储加载数据的寄存器,可以是一个寄存器或者多个寄存器组成的列表。如果是 x1 ,则需要提供 1 个寄存器;如果是 x2,则需要提供 2 个寄存器;如果是 x4,则需要提供 4 个寄存器。

p 表示加载数据的共享内存地址,需要传入指向共享内存的指针。使用 m8n8 模式加载一个 8x8 的矩阵时,并不要求 8x8 的矩阵的所有行在内存中连续存储,只需要每一行的数据是连续存储的即可。8 行数据需要 8 个地址。这 8 个地址分别由前 8 个线程来提供(其他线程传入的地址不被使用)。然后整个 warp 的 32 个线程协同完成 8x8 矩阵的加载,并将所有数据分散到 32 个线程提供的寄存器中。

不同的 .num 参数对应的寄存器地址和线程的映射关系如下表所示:

.numThreads 0–7Threads 8–15Threads 16–23Threads 24–31
.x1addr0–addr7
.x2addr0–addr7addr8–addr15
.x4addr0–addr7addr8–addr15addr16–addr23addr24–addr31

使用 ldmatrix 加载 8×8 矩阵的示例代码:

__shared__ half smem_a[8][8];  // 共享内存中的 8×8 矩阵

int lane = threadIdx.x % 32;   // 当前线程在 warp 中的编号
int row = lane % 8;            // 前 8 个线程(lane 0-7)各提供一行的地址
int col = 0;                   // 从第 0 列开始读取
uint32_t* addr = reinterpret_cast<uint32_t*>(&smem_a[row][col]);
uint32_t a0;  // 存储 2 个 fp16 元素(打包在 1 个 32 位寄存器中)

asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
    : "=r"(a0)     // 输出: 寄存器 a0
    :  "r"(addr)); // 输入: 共享内存地址

执行效果:Warp 内 32 个线程协同执行,将 8×8=64 个元素加载到各自的寄存器中,每个线程获得 2 个元素,打包存储在一个 32 位寄存器中。

共享内存地址计算

ldmatrix 的一个重要特性是:不要求矩阵在内存中完全连续。唯一的要求是每一行内部的元素连续。这种灵活性使得我们可以从大矩阵中加载子矩阵,或者处理有 padding 的矩阵。

地址提供机制

例如,从下面的 8×16 矩阵中加载右侧的 8×8 子矩阵时,线程 0-7 提供的地址如图所示:

图中数字表示线程编号,标记位置是该线程提供的内存地址(指向该行起始位置)

灵活性:加载的 8 行数据不一定要形成规则的矩阵块,只要每行内部连续即可。例如下图所示,可以从不同位置加载 8 行:

加载更大的矩阵:要加载 16×16 矩阵,使用 .x4 修饰符一次加载 4 个 8×8 块:

地址分布如下图所示:

加载 16×16 矩阵的代码示例:

__shared__ half smem_a[16][16];  // 共享内存中的 16×16 矩阵

int lane = threadIdx.x % 32;     // 线程编号 (0-31)
int row = lane % 16;             // 行号 (0-15),循环使用
int col = (lane / 16) * 8;       // 列偏移 (0 或 8)
uint32_t* addr = reinterpret_cast<uint32_t*>(&smem_a[row][col]);
uint32_t a0, a1, a2, a3;         // 4 个寄存器,对应 4 个 8×8 块

asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
    : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)  // 输出: 4 个寄存器
    :  "r"(addr));                             // 输入: 地址

数据分布:执行完成后,例如 lane 0 的线程中:

加载时转置

回顾 MMA 指令中 A 和 B 矩阵的数据布局差异:

图中小数字表示线程编号,大数字表示 8×8 子矩阵编号

关键区别

为了适应这种布局差异,ldmatrix 提供了 .trans 修饰符,允许在加载时进行转置。下图对比了转置与非转置加载 8×8 矩阵时的数据分布:

图中深色部分为线程 0-7 提供的地址,数字表示该元素所在的线程编号

代码示例 - 转置 vs 非转置加载:

__shared__ half smem[8][8];  // 共享内存中的 8×8 矩阵

int lane = threadIdx.x % 32;  
int row = lane;    // 前 8 个线程(0-7)各提供一行的地址
int col = 0;       // 从第 0 列开始
uint32_t* addr = reinterpret_cast<uint32_t*>(&smem[row][col]);

uint32_t a0, b0;

// 非转置加载 - 适用于行优先存储的 A 矩阵
asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
    : "=r"(a0) : "r"(addr));

// 转置加载 - 适用于列优先存储的 B 矩阵
asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
    : "=r"(b0) : "r"(addr));

MMA 指令中 A/B 矩阵的加载

场景设定:假设共享内存中

MMA 指令要求的数据分布方式如下图:

图中深色标记为线程提供的地址位置

加载策略

A 矩阵加载代码

加载 16×16 的 A 矩阵(MxK):

__shared__ half smem_a[16][16];  // MxK 矩阵

int lane = threadIdx.x % 32;
int row_a = lane % 16;        // 行号 (0-15)
int col_a = lane / 16 * 8;    // 列偏移 (0 或 8)
uint32_t* addr_a = reinterpret_cast<uint32_t*>(&smem_a[row_a][col_a]);
uint32_t a0, a1, a2, a3;

// 非转置加载,一次加载 4 个 8×8 块
asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
    : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
    :  "r"(addr_a));

地址计算

B 矩阵加载代码

加载 16×8 的 B 矩阵(K×N):

__shared__ half smem_b[16][8];  // KxN 矩阵

int row_b = lane;    // 线程 0-15 分别提供 16 行的地址
int col_b = 0;       // 从第 0 列开始
uint32_t* addr_b = reinterpret_cast<uint32_t*>(&smem_b[row_b][col_b]);
uint32_t b0, b1;

// 转置加载,一次加载 2 个 8×8 块
asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
    : "=r"(b0), "=r"(b1)
    :  "r"(addr_b));

地址计算

B 矩阵的替代存储布局

如果 B 矩阵在共享内存中以 N×K 格式存储,则可以直接使用非转置加载。下图展示了这种情况下的线程地址分布:

加载 N×K 布局的 B 矩阵代码:

__shared__ half smem_b[8][16];  // N×K 矩阵

int lane = threadIdx.x % 32;
int row = lane % 8;        // 行号 (0-7),线程 0-7 和 8-15 分别提供 8 行
int col = lane / 8 * 8;    // 列偏移, 0 或 8
uint32_t* addr = reinterpret_cast<uint32_t*>(&smem_b[row][col]);
uint32_t b0, b1;

// 非转置加载(因为已是行优先存储)
asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
    : "=r"(b0), "=r"(b1)
    :  "r"(addr));

关键点:当 B 矩阵以 N×K 存储时,已经是转置形式,无需再加 .trans 修饰符。

A/B 矩阵加载总结

在使用 MMA 指令前,需要根据矩阵在共享内存中的存储格式,合理选择 ldmatrix 的转置与否。一种简单的思维方式是考虑矩阵乘法的数学定义。对于 C = A × B,其中 C[0,0] 的计算涉及 A 的第 0 行和 B 的第 0 列。在加载完成 A 和 B 矩阵后,应该使得 A 的第 0 行数据分布在 0-3 号线程的寄存器中,而 B 的第 0 列数据也应该分布在 0-3 号线程的寄存器中。我猜想 MMA 这样要求数据分布的原因是为了让每个线程在计算 C 矩阵的元素时,可以直接使用自己寄存器中的数据,而不需要跨线程访问。

因此,在思考如何加载 A 和 B 矩阵时,可以从 C 矩阵的计算需求出发,确保每个线程在计算 C 矩阵元素时,能够直接使用自己寄存器中的数据。这样你就可以根据 B 矩阵在共享内存中的存储模式,选择是否需要转置加载。

存储结果到共享内存

MMA 指令执行完成后,结果矩阵 D 的数据分布与 C 矩阵一致,存储在各线程的寄存器中。若需要将结果写回共享内存,可以使用 stmatrix 指令。

stmatrix 指令格式

stmatrix.sync.aligned.shape.num{.trans}{.ss}.type [p], r;

.shape  = {.m8n8, .m16n8};     // 矩阵形状
.num    = {.x1, .x2, .x4};     // 存储的块数量
.ss     = {.shared{::cta}};    // 目标内存空间
.type   = {.b16, .b8};         // 元素位宽

stmatrixldmatrix 的逆操作,将寄存器中的数据写回共享内存。其参数含义、线程地址映射关系与 ldmatrix 相同。

重要说明

更多关于 ldmatrix 和 stmatrix 的使用示例,可参考 CUDA MMA 示例代码

使用 MMA 实现通用矩阵乘加

现在我们已掌握了:

  1. MMA 指令的调用方式
  2. 矩阵数据的分布规律
  3. 使用 ldmatrix 加载数据的方法

接下来介绍如何将这些技术组合起来,实现高效的通用矩阵乘法(GEMM)。

Tiling 策略:矩阵乘法通常采用分块(tiling)技术。对于 D = A × B + C 运算,其中 A(M×K),B(K×N),C/D(M×N),我们可以将计算分解为:

分块策略

关于矩阵乘法的 tiling 技术详情,可参考我之前的文章 CUDA 矩阵乘法优化。本文重点讲解如何使用 MMA 指令计算这些小块

MMA 块的大小匹配

MMA 指令支持的矩阵大小是固定的。本文介绍的 mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 支持 M=16,N=8,K=16。

简单方案:直接设置 Bm=16,Bn=8,Bk=16,但这会导致 tiling 粒度太小,无法充分利用 GPU 资源。

实用方案:设置 Bm=128,Bn=128,Bk=32(均为 MMA 尺寸的整数倍),然后在线程块内继续分块。下图展示了这种分块方式:

多层分块逻辑:就像外层将大矩阵分为 Bm×Bn 块一样,线程块内部也需要将 Bm×Bn 块继续划分为更小的 16×8 块,直到可以用 MMA 指令计算。

Warp 级并行分配

因为 MMA 是 warp 级指令,我们可以让多个 warp 并行处理不同的 16×8 块:

实现策略:线程块中启动 num_warps 个 warp,让它们并行处理 Bm×Bn 块中的多个 16×8 小块。下面是一种可能的实现:

int mma_m = Bm / 16;   // M 维度上的 MMA 块数量
int mma_n = Bn / 8;    // N 维度上的 MMA 块数量
int mma_k = Bk / 16;   // K 维度上的迭代次数

// 多个 warp 并行处理 mma_m × mma_n 个 16×8 块
for (int idx = warp_id; idx < mma_m * mma_n; idx += num_warps) {
    int m = idx / mma_n;  // 当前块在 M 维度的位置
    int n = idx % mma_n;  // 当前块在 N 维度的位置

    // 当前 warp 计算位于 (m, n) 位置的 16×8 块
}

分工逻辑:每个 warp 沿着 M×N 平面以步长 num_warps 迭代,处理自己负责的 16×8 块。

Warp 内部计算

在每个 warp 内部,需要沿 K 维度迭代,从共享内存加载 A 和 B 子块,然后调用 MMA 指令计算 16×8 块的乘加:

int row = m * 16;  // 当前 16×8 块的起始行
int col = n * 8;   // 当前 16×8 块的起始列

float c_regs[4] = {0.0f};  // 初始化累加寄存器

// 沿 K 维度迭代
for (int k = 0; k < mma_k; ++k) {
    uint32_t a_regs[4];  // A 矩阵寄存器 (16×16 的 8 个 fp16)
    uint32_t b_regs[2];  // B 矩阵寄存器 (16×8 的 4 个 fp16)

    // 从共享内存加载 A 矩阵子块 (16×16)
    int a_row = row + (lane % 16);
    int a_col = k * 16 + (lane / 16) * 8;
    uint32_t* addr = reinterpret_cast<uint32_t*>(&smem_a[a_row][a_col]);
    ldmatrix_4x(addr, a_regs[0], a_regs[1], a_regs[2], a_regs[3]);

    // 从共享内存加载 B 矩阵子块 (16×8)
    int b_row = k * 16 + lane;
    int b_col = col;
    uint32_t* baddr = reinterpret_cast<uint32_t*>(&smem_b[b_row][b_col]);
    ldmatrix_2x(baddr, b_regs[0], b_regs[1]);

    // 调用 MMA 指令进行矩阵乘加 (D = A × B + C)
    fma(
        c_regs[0], c_regs[1], c_regs[2], c_regs[3],  // D 和 C(输出/输入)
        a_regs[0], a_regs[1], a_regs[2], a_regs[3],  // A
        b_regs[0], b_regs[1],                         // B
        c_regs[0], c_regs[1], c_regs[2], c_regs[3]   // C(累加)
    );
}
// 此时 c_regs 中存储了该 16×8 块的最终结果

计算流程

  1. 初始化累加寄存器 c_regs 为 0
  2. 沿 K 维度迭代:
    • 从共享内存加载 A 和 B 的当前分块
    • 调用 MMA 指令,结果累加到 c_regs
  3. K 维度迭代完成后,c_regs 中包含最终结果
  4. c_regs 写回共享内存或全局内存

完整的 MMA GEMM 实现可参考:CUDA MMA GEMM 示例代码,在我实现的简化版本 flash-attention 的代码中也有使用 MMA 指令的例子。

避免 bank conflict

在使用共享内存时,bank conflict 是一个常见的性能瓶颈。共享内存通常划分为 32 个 bank,如果在一个 wavefront 内多个线程同时访问同一个 bank,就会发生冲突,导致访问串行化,影响性能。ldmatrixstmatrix 指令本质上也是使用基本的共享内存访问指令实现的,因此也可能受到 bank conflict 的影响。在使用 ldmatrixstmatrix 时,需要考虑如何避免 bank conflict。

通用的做法就是使用 swizzle 技术对共享内存中的数据进行重新布局,让一个 ldmatrix 操作中读取的 8 行数据映射到不同的 bank 上。以本文介绍的 m8n8 指令为例,ldmatrix 可以从共享内存中加载 8x8 矩阵。如果使用 x4 模式加载 16×16 矩阵,实际上就是重复了 4 次加载 8 行数据的过程。

我们已经知道,ldmatrix 指令本质上是从 8 个地址中分别读取 8 个两字节的元素,一共加载 128 字节的数据。而 32 个 bank 每个 bank 每次可以提供 4 字节的数据,因此在理想情况下,8 行数据可以分布在不同的 bank 上,这样就不会发生 bank conflict。所以,我们只需要按照 16 字节作为一个单位,对共享内存中的数据进行重新布局,就可以避免 bank conflict。关于 swizzle 技术的详细介绍,可以参考我之前的文章 Swizzle 的工作原理

比如,对于一个 8x64 的 fp16 矩阵,我们可以按照下面的方式进行 swizzle:

总结

本文系统介绍了 CUDA MMA 指令的使用方法,包括:

  1. MMA 指令基础:指令格式、参数含义、数据类型和寄存器使用
  2. 数据布局规则:A、B、C 矩阵在 warp 内 32 个线程中的分布模式
  3. ldmatrix 指令:高效加载共享内存数据到寄存器,支持转置加载
  4. 地址计算:如何为不同线程计算共享内存地址,处理各种存储布局
  5. 实际应用:如何将 MMA 指令集成到通用矩阵乘法(GEMM)中
  6. 性能优化:避免共享内存 bank conflict 的技巧

掌握 MMA 指令能够显著提升矩阵计算密集型应用的性能,希望本文能帮助你更好地理解和使用这一强大的 GPU 特性。

评论 (评论内容仅博主可见,不会公开显示)