在 TileLang 中編寫高效能核心:從 GEMM 到 MLA

TileLang 讓您能夠用 Python 編寫 GPU kernel,並對 tile、pipeline 和 warp 進行明確的控制。從簡單的 GEMM 到 DeepSeek 的 MLA 解碼皆可勝任。

在 TileLang 中編寫高效能核心:從 GEMM 到 MLA

如果您編寫 GPU Kernel,您其實是在一個光譜上遊走。光譜的一端是 Triton:編寫快速,但編譯器會為您處理大部分的佈局(layout)和共享記憶體(shared memory)決策。另一端是 CUTLASS / CuTe:擁有完全掌控權,但代價是大量的模板程式碼(template machinery)。TileLang 則位於兩者之間。您使用 Python 編寫程式,但必須明確指定什麼資料存放在共享記憶體中、管線(pipeline)如何分級、以及 Warp 如何劃分工作——剩下的部分則透過「佈局推斷(layout inference)」過程自動填入。

這篇文章將涵蓋心智模型,編寫一個 GEMM,然後構建一個真正的生產環境 Kernel:DeepSeek 的 MLA(Multi-Head Latent Attention)解碼,在這裡,有趣的決策才會真正體現出來。我們的目標不是面面俱到,而是展示您應該如何思考 Tile(分塊),以及 TileLang 如何在幕後輕鬆處理棘手的部分。最後,我們將分享一個來自生產環境的典型案例——一個速度提升並非唯一優勢的 Kernel。

心智模型

核心理念可歸納為三點:

  • Tile 是第一類物件(first-class object)。 一塊有形狀的資料區塊(例如 block_M × block_K)由一個執行緒區塊(thread block)、Warp 或執行緒來擁有與操作。您不必再像 Triton 那樣僅從執行緒區塊層級思考,也不必像在 CUDA 中那樣手動管理單個執行緒。
  • 由您親自放置記憶體層級中的緩衝區。 您需宣告哪些內容放入共享記憶體(
    text
    1T.alloc_shared
    )、哪些放入暫存器(
    text
    1T.alloc_fragment
    ),以及哪些屬於執行緒本地。這是與 Triton 最大的區別,Triton 將共享記憶體的配置和分級隱藏在編譯器中。
  • 編譯器負責推斷執行緒映射。 一旦您指定了 Tile 的位置以及要在其上執行的操作(複製、GEMM、歸約),「佈局推斷」過程會將其並行化到執行緒中,並規劃暫存器和共享記憶體的佈局。雖然您可以根據需要進行覆寫,但大多數情況下不需要。這個過程是支撐整個功能的關鍵——當我們深入 MLA 時,您就會明白為什麼。

如果您是從 Triton 轉過來的,對照表大致如下:

 TritonTileLang
粒度執行緒區塊 + 隱式向量化Tile(塊 / Warp / 執行緒)
共享記憶體由編譯器管理顯式
text
1alloc_shared
+
text
1copy
佈局由編譯器決定自動推斷,但可手動標註
管線化
text
1tl.range
+ 編譯器
顯式
text
1T.Pipelined(num_stages=)
Tensor Core
text
1tl.dot
text
1T.gemm
搭配可選的 Warp 策略
後端NVIDIA (主要) / AMDNVIDIA / AMD / CPU / WebGPU / CuTeDSL,外加 Ascend 與 MUSA 分支

簡而言之:如果您想在不編寫 CUTLASS 的情況下,精細控制區塊劃分、管線深度和 Warp 分割,TileLang 是最佳選擇。對於簡單的逐元素運算或輕量級融合,Triton 仍然是更快捷的工具。

設定環境

plaintext
1conda create -n tilelang python=3.10 -y
2conda activate tilelang
3pip install tilelang                 # 預先構建的 wheel,最簡單的方式

如果您打算修改編譯器流程,請改用原始碼構建(您需要本地的 LLVM/CUDA 工具鏈):

plaintext
1git clone --recursive https://github.com/tile-ai/tilelang.git
2cd tilelang && pip install -r requirements-dev.txt
3pip install -e . -v --no-build-isolation

編寫一個 GEMM

我們從每個人都熟悉的 Kernel 開始:

text
1C = ReLU(A @ B)
。它雖然簡單,但涵蓋了所有重要的原語:顯式緩衝區、並行複製、軟體管線化、Tensor Core 呼叫以及 L2 swizzle。

plaintext
1import tilelang
2import tilelang.language as T
3import torch
4
5@tilelang.jit
6def matmul(M, N, K, block_M, block_N, block_K,
7           dtype="float16", accum_dtype="float"):
8
9    @T.prim_func
10    def matmul_relu_kernel(
11        A: T.Tensor((M, K), dtype),
12        B: T.Tensor((K, N), dtype),
13        C: T.Tensor((M, N), dtype),
14    ):
15        # 網格維度:(沿 N 的塊數, 沿 M 的塊數);每個區塊 128 個執行緒
16        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M),
17                      threads=128) as (bx, by):
18
19            # 明確指定每個 Tile 的位置
20            A_shared = T.alloc_shared((block_M, block_K), dtype)         # 共享記憶體
21            B_shared = T.alloc_shared((block_K, block_N), dtype)
22            C_local  = T.alloc_fragment((block_M, block_N), accum_dtype) # 暫存器累加器
23
24            T.use_swizzle(panel_size=4, order="col")   # 選項:更好的 L2 重用
25            T.clear(C_local)                           # 清空累加器
26
27            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
28                T.copy(A[by * block_M, ko * block_K], A_shared)   # 全局 -> 共享
29                T.copy(B[ko * block_K, bx * block_N], B_shared)
30                T.gemm(A_shared, B_shared, C_local)               # Tile 層級 MMA
31
32            for i, j in T.Parallel(block_M, block_N):             # 融合 ReLU
33                C_local[i, j] = T.max(C_local[i, j], 0)
34
35            T.copy(C_local, C[by * block_M, bx * block_N])        # 回寫
36
37    return matmul_relu_kernel
38
39
40M = N = K = 1024
41kernel = matmul(M, N, K, block_M=128, block_N=128, block_K=64)
42a = torch.randn(M, K, device="cuda", dtype=torch.float16)
43b = torch.randn(K, N, device="cuda", dtype=torch.float16)
44c = torch.empty(M, N, device="cuda", dtype=torch.float16)
45kernel(a, b, c)
46
47torch.testing.assert_close(c, torch.relu(a @ b), rtol=1e-2, atol=1e-2)
48print("gemm ok")

各部分的作用如下:

  • 三個緩衝區,三個層級。
    text
    1A_shared
    text
    1B_shared
    存於共享記憶體;
    text
    1C_local
    存於暫存器。暫存器中的累加器,以及透過共享記憶體分級的操作數——這是 GEMM 的標準配方,唯一的區別在於這裡是由「您」親自寫下。這就是與 Triton 的核心差異。
  • text
    1T.copy
    是並行複製的語法糖。 它會展開為
    text
    1T.Parallel
    風格的移動,編譯器會從中導出向量化、合併後的「全局→共享」傳輸。當
    text
    1copy
    位於
    text
    1T.Pipelined
    內部時,它會自動變為
    text
    1cp.async
  • text
    1T.Pipelined(extent, num_stages=N)
    是軟體管線。
    text
    1num_stages=3
    表示三級緩衝——在計算 K-tile
    text
    1ko
    的同時,
    text
    1ko+1
    text
    1ko+2
    的負載已經在傳輸中。在 Triton 中,這是一個編譯參數;而在這裡,它只是迴圈的一部分,更容易理解。
  • text
    1T.gemm(A, B, C)
    是 Tile 層級的矩陣乘法。 它在 NVIDIA 上會降低為 CuTe/MMA,在 AMD 上則對應到相應的指令集。它還接受
    text
    1transpose_A / transpose_B
    以及一個
    text
    1policy=T.GemmWarpPolicy.*
    參數,用來控制 Warp 如何分割輸出 Tile。請記住這個策略參數,它是我們之後提到 MLA 時的重點。
  • text
    1T.use_swizzle
    會重新排列執行緒區塊的排程,使 L2 中相鄰的區塊在時間上也能連續執行,通常能帶來幾個百分點的頻寬提升。

下圖將這一切對應到硬體上。值得對照程式碼閱讀,因為標註的區域正是 TileLang 將 Triton 隱藏的控制權交還給您的地方。

GEMM in TileLang — 您親自將每個緩衝區放置在記憶體層級中。A_shared / B_shared 位於共享記憶體,C_local 在 Warp W0–W3 的暫存器中累積,K 迴圈管線(num_stages=3)將 cp.async 預取與當前的 gemm 計算重疊。

您將經常用到的幾個原語

您可以使用有限的詞彙編寫大多數 Kernel:

  • 配置:
    text
    1T.alloc_shared
    text
    1T.alloc_fragment
    (暫存器)、
    text
    1T.alloc_local
  • 移動與初始化: 任意兩層級間的
    text
    1T.copy(src, dst)
    text
    1T.clear
    text
    1T.fill
  • 計算:
    text
    1T.gemm(...)
    ;用於逐元素迴圈的
    text
    1T.Parallel(d0, d1, ...)
    (這是佈局推斷的入口點);
    text
    1T.reduce_max
    /
    text
    1T.reduce_sum
    ;如
    text
    1T.exp
    text
    1T.exp2
    text
    1T.max
    text
    1T.infinity
    等標量數學運算。
  • 排程:
    text
    1T.Pipelined(extent, num_stages=)
    text
    1T.use_swizzle(...)
    、以及當您需要特定佈局時使用的
    text
    1T.annotate_layout(...)
    (用於避免 Bank Conflict、自訂 Swizzle)。
  • 動態形狀:
    text
    1M = T.dynamic("m")
    ,這樣您就不需要針對每個形狀重新編譯(在某些版本中稱為
    text
    1T.symbolic
    )。

驗證結果

您通常會需要兩件事。查看編譯器實際輸出的程式碼:

plaintext
1print(kernel.get_kernel_source())     # 產生的 CUDA / HIP

以及進行效能分析:

plaintext
1profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
2print(f"latency: {profiler.do_bench()} ms")

text
1T.print(buf)
會在 Kernel 內部列印出一個 Tile,而專案中的
text
1examples/plot_layout
可以繪製記憶體佈局,這在追蹤 Bank Conflict 或檢查 Swizzle 時非常有用。

實際案例:MLA 解碼

GEMM 展示了機制,而下一個案例則展示了這些機制為何重要。我們將剖析 DeepSeek 的 MLA(Multi-Head Latent Attention)解碼 Kernel,因為這是 TileLang 體現其價值的最佳範例。TileLang 的參考實作在大約 80 行 Python 程式碼中,達到了與 FlashMLA 在 H100 上的相當效能(在 fp16、batch 64/128 下進行基準測試,輕鬆領先 Triton 和 FlashInfer)。有趣的問題是它是如何做到的,因為 MLA 的難點不在於數學,而在於暫存器壓力。

讓我們複習一下大家熟知的迴圈。每個 FlashAttention 系列的 Kernel 都有相同的形狀。對於每個 Query 區塊,您串流處理 Key/Value 區塊並維持一個移動的最大值和分母,這樣完整的計分矩陣就不會存入記憶體:

plaintext
1# acc_s : [block_M, block_N]  此 KV 區塊的計分
2# acc_o : [block_M, dim]      輸出累加器
3for i in range(num_kv_blocks):
4    acc_s = Q @ K[i].T
5    m_prev       = scores_max
6    scores_max   = max(m_prev, rowmax(acc_s))
7    scores_scale = exp(m_prev - scores_max)
8    acc_o *= scores_scale                       # 重新調整先前的輸出
9    acc_s  = exp(acc_s - scores_max)            # 機率
10    acc_o += acc_s @ V[i]

text
1acc_s
text
1acc_o
都希望保留在暫存器中。對於 MHA 或 GQA,這沒問題。但對於 MLA,則不然。

難點在這裡。 MLA 的 Head 維度很大:Query 和 Key 是 576 寬(包含 512 寬的「nope」部分(無位置編碼)加上 64 寬的「rope」部分),Value 則是 512。因此

text
1acc_o
text
1[block_M, 512]
,並且必須在整個 KV 迴圈中駐留在暫存器中。

現在引入硬體。在 Hopper 上,快速路徑是

text
1wgmma.mma_async
,它將 4 個 Warp(128 個執行緒)綁定成一個 Warpgroup,並要求 M 至少為 64。因此一個 Warpgroup 能擁有的最小 M 為 64,這意味著一個 Warpgroup 將持有 64 × 512 的累加器。這對於單個 Warpgroup 的暫存器檔案來說太大了。它會發生溢出(spill),效能隨之崩潰。

MLA decode in TileLang — 將 acc_o 分割給兩個 Warpgroup。WG0 和 WG1 各自計算 Q·K^T (策略=FullCol),透過 S_shared 交換它們的計分結果,然後各自將 P·V 的列條(column slab)計算到 acc_o_L / acc_o_R 中。所有的簿記工作(acc_s 形狀、S_shared 形狀、Q·K 分割)都是由您標註的 FullCol 策略透過佈局推斷導出的。

解決方案是將輸出分割到兩個 Warpgroup 上。 您不能將 M 縮小到 64 以下,所以剩下的軸只有維度(dim)。使用兩個 Warpgroup:WG0 擁有

text
1acc_o[:, :256]
,WG1 擁有
text
1acc_o[:, 256:]
。現在每個 Warpgroup 持有一個 64 × 256 的累加器,這就裝得下了。然而,這帶來了第二個問題:P @ V 步驟(使用
text
1policy=FullCol
,每個 Warpgroup 產生輸出的一個列條)需要完整的
text
1acc_s
,但在 Q @ K 中,每個 Warpgroup 自然只計算了一半。解決方案是共享記憶體交換。在 Q @ K 期間,每個 Warpgroup 將其一半的
text
1acc_s
寫入共享記憶體,並讀取另一個 Warpgroup 的另一半,這樣兩者都持有了完整的
text
1acc_s
,從而都能計算出各自的輸出條。上圖正是如此:分割計分,透過
text
1S_shared
交換,分割輸出。

在 CuTe 中,您需要手動編寫佈局、Swizzle、Tensor Core 對齊以及生產者/消費者同步來完成此操作。在這裡能壓縮到約 80 行的原因是佈局推斷。

讓我們分解佈局推斷的作用。 您在

text
1T.gemm
呼叫上標註意圖,它會為您在程式中傳遞約束:

  1. P @ V 上的
    text
    1policy=FullCol
    意味著每個 Warpgroup 需要完整的
    text
    1acc_s
    ,因此
    text
    1acc_s = [block_M, block_N]
  2. 這會回溯到分級緩衝區,因此
    text
    1T.copy(S_shared, acc_s)
    中的
    text
    1S_shared
    也是
    text
    1[block_M, block_N]
  3. 並向前推導到 Q @ K:使用
    text
    1FullCol
    時,每個 Warpgroup 的計分條為
    text
    1[block_M, block_N/2]

關鍵洞察在於您從未親手撰寫任何這些形狀。您選擇 Warp 策略並編寫數學公式;形狀、Swizzle 佈局以及 Warp 專用的生產者/消費者程式碼都是推斷的結果。

Kernel 架構。 在 MLA 解碼中,Query 分割為「nope」部分(Q,維度 512)和「rope」部分(Q_pe,維度 64),壓縮後的潛在變數同時作為 K 和 V。因此計分是兩個 GEMM 的總和,輸出則是另一個。內部迴圈看起來像這樣(代表性架構,非精確程式碼 — 請參考

text
1example_mla_decode.py
):

plaintext
1# acc_s = Q_nope @ KV^T + Q_rope @ K_pe^T
2T.gemm(Q_shared,    KV_shared,   acc_s, transpose_B=True,
3       policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
4T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True,
5       policy=T.GemmWarpPolicy.FullCol)
6
7# 線上 Softmax
8T.copy(scores_max, scores_max_prev)
9T.fill(scores_max, -T.infinity(accum_dtype))
10T.reduce_max(acc_s, scores_max, dim=1, clear=False)
11# ... exp, 使用 scores_scale 重新調整 acc_o, 歸約到 logsum 中 ...
12
13# acc_o += P @ V  (V 與潛在 KV 相同)
14T.copy(acc_s, acc_s_cast)
15T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)

一旦

text
1FullCol
策略強迫每個 Warpgroup 擁有完整的
text
1acc_s
,兩個 Warpgroup 之間的
text
1S_shared
交換就會由推斷自動插入。

優點:最佳化僅需一行。 這是 TileLang 的回報——整個效能工具包都是單行指令,複雜的底層轉換由系統為您處理。

  • 用於 L2 重用的執行緒區塊 Swizzling:
    text
    1T.use_swizzle(panel_size, order="row")
  • 用於避免 Bank Conflict 的共享記憶體 Swizzling:
    text
    1T.annotate_layout({S_shared: T.layout.make_swizzled_layout(S_shared)})
    — XOR 風格的位址重映射,使並發存取分散到各個 Bank,而不是序列化。
  • Warp 專用化: 您編寫一個簡單的腳本,它會被降低為生產者 Warpgroup(TMA 載入)加上消費者 Warpgroup,並自動產生所有的
    text
    1mbarrier
    同步。這些都不會出現在您的程式碼中。
  • 管線化:
    text
    1T.Pipelined(range, num_stages)
    將載入與計算重疊——級數越多,重疊越多,但共享記憶體佔用也越大,這是一個可調參數。
  • Split-KV (FlashDecoding 風格): 當 Batch 很小且 SM 空閒時,將 KV 上下文分割到多個 SM 並合併。這只是一個
    text
    1num_split
    參數加上一個合併 Kernel 的事。

因此,真正困難的推理——暫存器預算對應 M≥64 的底限、Warp 間的權責劃分、共享記憶體交換——您只需透過選擇策略和撰寫數學公式來表達。原本在 CuTe 中需要數百行脆弱程式碼才能實現的功能,這裡只需透過推斷和程式碼產生。這就是賣點,也是 MLA 案例最具說服力的地方。

我們自己的案例:AtlasCloud 的 RMSNorm 外掛

最後一個例子是我們在 AtlasCloud 自己的生產 Kernel 之一,用於 H100/H200 上的 Wan 影片生成 VAE。這是 TileLang 另一個擅長點的絕佳範例:以簡潔的外掛(drop-in)方式,覆蓋手動編寫的 Kernel 無法達到的配置。

背景。 我們已經發佈了一個手動編寫的融合 RMSNorm + SiLU Kernel。它很快,並且針對模型配置中使用的隱藏維度

text
1D ∈ {96, 192, 384}
進行了編譯。一個較新的配置需要如
text
1{160, 256, 320, 512, 640, 1024}
的通道寬度,因此在該配置上手動編寫的快速路徑無法執行。我們編寫了一個 TileLang 外掛來填補這個缺口。

TileLang Kernel。 一個具有相同介面(BTHWC 輸入/輸出,相同數學,相同 epsilon)且支援任何 32 倍數 C 的外掛。兩趟(two-pass)處理,完全合併,FP32 累加器:

plaintext
1@T.prim_func
2def main(X:     T.Tensor((M, C), dtype),      # M = B*T*H*W 行
3         gamma: T.Tensor((C,),  dtype),
4         Y:     T.Tensor((M, C), dtype)):
5    with T.Kernel(T.ceildiv(M, BLOCK_M), threads=128) as bm:
6        X_chunk = T.alloc_shared((BLOCK_M, BLOCK_C), dtype)
7        ss      = T.alloc_fragment((BLOCK_M,), accum_dtype)   # FP32 平方和
8        # 第 1 趟:遍歷 C 的 BLOCK_C 區塊,以 FP32 累加平方和
9        # rinv = rsqrt(ss / C + 1e-5)
10        # 第 2 趟:重新載入 X,y = silu(x * gamma * rinv),回寫

text
1BLOCK_C
根據 C 的值設為 128/64/32,以遵守 TMA
text
1boxDim ≤ 256
的限制,FP32 累加器確保平方和不會在 FP16 中溢出。調度器會在路徑可用時保留手動編寫的路徑,只有在必要時才會回退:

plaintext
1_ATLAS_SUPPORTED_D = (96, 192, 384)
2
3def rms_silu_dispatch(x, gamma, out):
4    if x.shape[-1] in _ATLAS_SUPPORTED_D:
5        atlas_rms_norm_silu(x, gamma, out=out)        # 保留手動編寫的路徑
6    else:
7        tilelang_rms_silu_bthwc(x, gamma, out=out)    # 填補缺口

收穫。 全都是優勢,這是一個真正的外掛——相同的介面、相同的數學、相同的 epsilon,因此它可以無縫插入現有的調度邏輯中,無需更改呼叫位置。

內容增益
先前不支援的配置0 → 1 — 現在可以執行(關鍵突破)
Attention-block RMSNorm 對比取代的急切式 PyTorch Norm42 μs → 20 μs (~2 倍快)
生產解析度下的端到端 VAE (720×1280, 21 幀)~1.79× 編碼, ~1.78× 解碼

第一行是重點所在:TileLang 讓我們能夠為先前沒有快速路徑的模型配置提供服務,且無需動到已經為其他配置優化好的手動 Kernel。只需一個用 Python 編寫的外掛,整個模型路徑就從「崩潰」變成了「生產可用」。

TileLang 的優勢

  • 無需編寫 CUTLASS/CuTe,即可精細控制區塊劃分、管線級數和 Warp 分割。
  • 處理結構複雜、對佈局敏感的 Kernel:GEMM 變體、FlashAttention 系列、MLA、線性注意力、反量化融合 GEMM、MoE 路由。
  • 覆蓋您的手動 Kernel 無法觸及的算子或配置(隱藏維度超出定義範圍、非典型佈局)——並且在此過程中擊敗了急切式的備援。
  • 跨後端(NVIDIA / AMD / 廠商分支)統一的 Kernel 實作。
  • 整個最佳化工具包只需單次呼叫 —
    text
    1T.use_swizzle
    text
    1T.annotate_layout
    text
    1T.Pipelined
    、Warp 專用化、Split-KV — 底層轉換由系統處理。

總結

TileLang 最酷的地方在於,困難的推理過程保留在您的腦中,而不是大量的樣板程式碼中。您決定如何劃分 Warp 間的工作、緩衝區存放在哪裡、管線跑多深——然後佈局推斷和 Warp 專用化會將其轉化為暫存器佈局、Swizzle 和生產者/消費者舞步,而這些在 CuTe 中通常需要數百行程式碼。您選擇策略並編寫數學公式。這就是它的賣點,也是為什麼一個 80 行的 MLA Kernel 能與手動編寫的 CUTLASS Kernel 並肩作戰的原因。

最新模型

一個 API,暢享全模態 AI。

探索全部模型

Join our Discord community

Join the Discord community for the latest model updates, prompts, and support.