使用 TileLang 编写高性能内核:从 GEMM 到 MLA

TileLang 让你能够通过 Python 编写 GPU 内核,并对 Tile、流水线(pipeline)和线程束(warp)进行显式控制。无论是简单的 GEMM 还是 DeepSeek 的 MLA 解码,均可轻松实现。

使用 TileLang 编写高性能内核:从 GEMM 到 MLA

编写 GPU 内核时,你实际上是处在一条技术光谱上。一端是 Triton:编写速度快,但编译器会为你处理大部分内存布局和共享内存的决策。另一端是 CUTLASS / CuTe:提供完全控制力,代价是需要编写大量的模板代码。TileLang 则位于中间。你可以用 Python 编写逻辑,但需要显式定义共享内存的内容、流水线的阶段以及线程束(warp)如何分工——而剩下的工作则由布局推导(layout inference)阶段来完成。

在这篇文章中,我们将介绍其思维模型,编写一个 GEMM,然后深入到真实的生产级内核:DeepSeek 的 MLA 解码(decode),这里才能体现出真正有趣的决策。我们的目的不是面面俱到,而是展示你如何思考“分块”(tiles),以及 TileLang 如何默默地为你完成那些困难的部分。最后,我们将分享一个典型的生产案例——一个性能提升并非首要目标的内核。

思维模型

核心思想概括为三点:

  • Tile 是头等对象。 一个数据块(例如 block_M × block_K)由线程块、线程束或单个线程拥有并处理。你不再像在 Triton 中那样完全停留在线程块层面,也不需要像在 CUDA 中那样手动管理每个线程。
  • 你自己决定内存层级中的缓冲区位置。 你明确声明哪些数据进入共享内存 (T.alloc_shared),哪些进入寄存器 (T.alloc_fragment),以及哪些是线程局部的。这是它与 Triton 最大的区别,Triton 将共享内存的分配和调度隐藏在编译器中。
  • 编译器推导线程映射。 一旦你定义了 tile 的位置以及在其上运行的操作(复制、gemm、归约),“布局推导”阶段就会将其并行化分配给线程,并计算出寄存器和共享内存的布局。虽然你可以手动重写,但大多数情况下无需干预。这是该工具的核心功能——当你看到 MLA 的例子时,就会明白其价值。

如果你有 Triton 的使用经验,可以参考以下映射表:

 TritonTileLang
粒度线程块 + 隐式向量化tile(线程块 / 线程束 / 线程)
共享内存由编译器管理显式 alloc_shared + copy
布局编译器决定推导得出,也可手动标注
流水线tl.range + 编译器显式 T.Pipelined(num_stages=)
Tensor Coretl.dotT.gemm,可选择线程束策略
后端NVIDIA (主要) / AMDNVIDIA / AMD / CPU / WebGPU / CuTeDSL,以及 Ascend & MUSA 分支

简而言之:如果你想在不编写 CUTLASS 的前提下,精确控制分块、流水线深度和线程束分区,TileLang 是最佳选择。对于简单的逐元素操作或轻量级融合,Triton 仍然更快。

环境配置

plaintext
1conda create -n tilelang python=3.10 -y
2conda activate tilelang
3pip install tilelang                 # 预编译包,最快路径

如果你需要修改编译器代码,建议从源码构建(需要本地 LLVM/CUDA 工具链):

plaintext
1git clone --recursive https://github.com/tile-ai/tilelang.git
2cd tilelang && pip install -r requirements-dev.txt
3pip install -e . -v --no-build-isolation

编写一个 GEMM

我们从经典的 C = ReLU(A @ B) 内核开始。它很小,但涵盖了所有关键的原语:显式缓冲区、并行复制、软件流水线、Tensor Core 调用以及 L2 swizzle。

plaintext
1import tilelang
2import tilelang.language as T
3import torch
4
5@tilelang.jit
6def matmul(M, N, K, block_M, block_N, block_K,
7           dtype="float16", accum_dtype="float"):
8
9    @T.prim_func
10    def matmul_relu_kernel(
11        A: T.Tensor((M, K), dtype),
12        B: T.Tensor((K, N), dtype),
13        C: T.Tensor((M, N), dtype),
14    ):
15        # 网格维度:(#blocks along N, #blocks along M); 每个块 128 个线程
16        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M),
17                      threads=128) as (bx, by):
18
19            # 显式定义每个 tile 的存储位置
20            A_shared = T.alloc_shared((block_M, block_K), dtype)         # 共享内存
21            B_shared = T.alloc_shared((block_K, block_N), dtype)
22            C_local  = T.alloc_fragment((block_M, block_N), accum_dtype) # 寄存器累加器
23
24            T.use_swizzle(panel_size=4, order="col")   # 可选:更好的 L2 重用
25            T.clear(C_local)                           # 清零累加器
26
27            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
28                T.copy(A[by * block_M, ko * block_K], A_shared)   # 全局 -> 共享
29                T.copy(B[ko * block_K, bx * block_N], B_shared)
30                T.gemm(A_shared, B_shared, C_local)               # tile 级 MMA
31
32            for i, j in T.Parallel(block_M, block_N):             # 融合 ReLU
33                C_local[i, j] = T.max(C_local[i, j], 0)
34
35            T.copy(C_local, C[by * block_M, bx * block_N])        # 写回
36
37    return matmul_relu_kernel
38
39
40M = N = K = 1024
41kernel = matmul(M, N, K, block_M=128, block_N=128, block_K=64)
42a = torch.randn(M, K, device="cuda", dtype=torch.float16)
43b = torch.randn(K, N, device="cuda", dtype=torch.float16)
44c = torch.empty(M, N, device="cuda", dtype=torch.float16)
45kernel(a, b, c)
46
47torch.testing.assert_close(c, torch.relu(a @ b), rtol=1e-2, atol=1e-2)
48print("gemm ok")

代码各部分的意义:

  • 三个缓冲区,三个层级。 A_shared 和 B_shared 位于共享内存;C_local 位于寄存器中。这是 GEMM 的标准做法,区别在于这行代码是你亲手写下的,而不是像 Triton 那样由编译器暗中处理。
  • T.copy 是并行复制的语法糖。 它展开为 T.Parallel 风格的移动,编译器从中推导出向量化、合并后的全局到共享内存的传输。当它位于 T.Pipelined 中时,会自动变为 cp.async。
  • T.Pipelined(extent, num_stages=N) 是软件流水线。 num_stages=3 表示三倍缓冲——在计算当前 K-tile 的同时,ko+1 和 ko+2 的载入已经在进行中。在 Triton 中这需要编译器标志,而在 TileLang 中它就是循环的一部分,更易于理解。
  • T.gemm(A, B, C) 是 tile 级矩阵乘法。 它在 NVIDIA 平台上降低为 CuTe/MMA,在 AMD 上则匹配对应的内联指令。它还接受 transpose 参数和 GemmWarpPolicy,后者控制线程束如何拆分输出 tile。这一点对 MLA 非常关键。
  • T.use_swizzle 重新调度线程块,使 L2 缓存中相邻的块在执行时间上也更接近,通常能带来几个百分点的带宽提升。

下方的示意图将这些内容映射到了硬件上,它精确展示了 TileLang 将哪些本由编译器控制的权限交回给了你。

GEMM in TileLang — 你亲自定义了层级中的每个缓冲区。A_shared / B_shared 位于共享内存,C_local 在线程束 W0–W3 的寄存器中累加,K-loop 流水线 (num_stages=3) 将 cp.async 预取与 gemm 计算重叠。

常用原语

只需掌握少量词汇即可编写大部分内核:

  • 分配: T.alloc_shared, T.alloc_fragment (寄存器), T.alloc_local。
  • 移动与初始化: T.copy(src, dst)(跨层级);T.clear, T.fill。
  • 计算: T.gemm(...); T.Parallel(d0, d1, ...) 用于逐元素循环(布局推导入口);T.reduce_max / T.reduce_sum;标量数学运算如 T.exp, T.exp2, T.max, T.infinity。
  • 调度: T.Pipelined(extent, num_stages=), T.use_swizzle(...), T.annotate_layout(...)(用于处理 bank 冲突或自定义 swizzle)。
  • 动态形状: M = T.dynamic("m") 以避免针对每个形状重新编译。

验证与调试

获取编译器生成的源代码:

plaintext
1print(kernel.get_kernel_source())     # 生成的 CUDA / HIP

获取性能数据:

plaintext
1profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
2print(f"latency: {profiler.do_bench()} ms")

实战:MLA 解码

DeepSeek 的 MLA 解码内核是 TileLang 的最佳展示窗口。其性能达到了 FlashMLA 在 H100 上的水平(在 batch 64/128 fp16 下,明显优于 Triton 和 FlashInfer),代码仅约 80 行 Python。MLA 的核心难点不在数学计算,而在寄存器压力

FlashAttention 系列内核的结构基本一致:每个 query 块遍历 key/value 块,维护当前的 max 和分母。然而对于 MLA,头维度非常大(查询和键宽 576,值宽 512),这意味着累加器 acc_o = [block_M, 512] 必须在整个 KV 循环中驻留在寄存器中。

在 Hopper 架构上,最快路径是 wgmma.mma_async,它将 4 个线程束(128 线程)捆绑为一个线程束组(warpgroup),要求 M 至少为 64。一个线程束组持有 64 × 512 的累加器会溢出寄存器,导致性能大幅下降。

MLA decode in TileLang — 将 acc_o 拆分到两个线程束组。WG0 和 WG1 各计算 Q·K^T 的一半,通过 S_shared 交换得分,然后分别计算各自的 P·V 输出分块。所有的记账工作(形状、交换逻辑)都通过你标注的 FullCol 策略由布局推导自动生成。

解决方案是将输出拆分到两个线程束组。 由于 M 不能小于 64,只能通过维度轴拆分:两个线程束组各持有一半的累加器(64 × 256)。这引入了第二个问题:P @ V 计算需要完整的 acc_s,而 Q @ K 阶段每个线程束组只计算了一半。解决方法是通过共享内存交换数据:在 Q @ K 期间,每个线程束组将自己的一半 acc_s 写入共享内存并读取对方的,从而确保两者都持有完整的 acc_s。

在 CuTe 中,你需要手动编写布局、swizzle、Tensor Core 对齐和生产者/消费者同步。而在 TileLang 中,只需通过标注 policy 即可由布局推导完成。

布局推导的作用:

  1. P @ V 上的 policy=FullCol 意味着每个线程束组需要完整的 acc_s。
  2. 约束向后传播到 staging buffer:S_shared 自动变为 [block_M, block_N]。
  3. 向前传播到 Q @ K:每个线程束组的得分分块为 [block_M, block_N/2]。

你只需编写数学逻辑,形状、swizzle 布局和专门的同步代码全由推导生成。

AtlasCloud 的生产实践:RMSNorm

我们在 AtlasCloud 的 Wan 视频生成 VAE 中使用了一个 TileLang 编写的 RMSNorm。原有手写优化内核仅支持部分隐藏维度(D ∈ {96, 192, 384}),而新配置需要 {160, 256, 320, 512, 640, 1024} 等通道宽度。TileLang 的通用性完美填补了这一空白。

TileLang 内核实现了一个与手写版本接口完全一致的 RMSNorm + SiLU 算子,支持任何 32 的倍数维度。

优势:

  • 零修改接入: 接口、数学逻辑与 eps 参数与现有代码完全兼容。
  • 性能提升: 在 Attention-block 的 RMSNorm 中,从 42 μs 提升至 20 μs(约 2 倍快)。
  • 端到端加速: VAE 在生产分辨率下(720×1280, 21 帧)编码速度提升约 1.79 倍,解码提升约 1.78 倍。

总结

TileLang 的精妙之处在于,复杂的逻辑判断保留在你的脑海中,而不是转化为大量的样板代码。你决定如何拆分线程束、分配缓冲区以及设置流水线深度——剩下的寄存器布局、swizzle 逻辑和复杂的生产者/消费者同步由布局推导自动完成。你只需选择策略并编写数学逻辑,这正是它能让 80 行的 MLA 内核与手写 CUTLASS 内核同台竞技的原因。

最新模型

300+ 模型,即刻开启,

探索全部模型

Join our Discord community

Join the Discord community for the latest model updates, prompts, and support.