Escrevendo Kernels de Alto Desempenho em TileLang, de GEMM a MLA

O TileLang permite que você escreva kernels de GPU em Python com controle explícito sobre tiles, pipelines e warps. De uma simples GEMM até a decodificação MLA do DeepSeek.

Escrevendo Kernels de Alto Desempenho em TileLang, de GEMM a MLA

Se você escreve kernels de GPU, você se encontra em algum ponto de um espectro. Em uma ponta está o Triton: rápido de escrever, mas o compilador toma a maioria das decisões de layout e memória compartilhada para você. Na outra ponta está o CUTLASS / CuTe: controle total, ao custo de muito maquinário de templates. O TileLang fica no meio. Você escreve em Python, mas diz explicitamente o que vive na memória compartilhada, como o pipeline é organizado e como os warps dividem o trabalho — e uma passagem de inferência de layout preenche o restante.

Neste post, abordaremos o modelo mental, escreveremos um GEMM e, em seguida, construiremos um kernel de produção real: o decode de MLA do DeepSeek, onde as decisões interessantes realmente aparecem. O objetivo não é ser exaustivo. É mostrar como você pensa sobre blocos (tiles) e onde o TileLang silenciosamente faz as partes difíceis para você. Terminaremos com uma história mais típica de produção — um kernel onde a vitória não foi a velocidade.

O modelo mental

Aqui está a ideia completa em três pontos.

  • Um tile é um objeto de primeira classe. Um pedaço de dados com formato definido (block_M × block_K, por exemplo) é de propriedade e operado por um bloco de threads, um warp ou uma thread. Você para de pensar puramente no nível de bloco de threads como faz no Triton, e para de gerenciar manualmente threads individuais como faz no CUDA.
  • Você mesmo coloca os buffers na hierarquia de memória. Você declara o que vai para a memória compartilhada (T.alloc_shared), o que vai para registradores (T.alloc_fragment) e o que é local à thread. Esta é a maior diferença em relação ao Triton, que esconde a alocação de memória compartilhada e o staging dentro do compilador.
  • O compilador infere o mapeamento de threads. Uma vez que você definiu onde um tile vive e qual operação é executada nele (uma cópia, um gemm, uma redução), uma passagem de inferência de layout o paraleliza entre as threads e resolve os layouts de registradores e de memória compartilhada. Você pode substituir isso quando precisar, mas na maioria das vezes não precisará. Esta passagem é o recurso de sustentação — quando chegarmos ao MLA, você verá o porquê.

Se você vem do Triton, aqui está o mapeamento aproximado.

 TritonTileLang
Granularidadebloco de threads + vetorização implícitatile (bloco / warp / thread)
Memória comp.gerenciada pelo compiladorexplicit alloc_shared + copy
Layouto compilador decideinferido, mas você pode anotar
Pipeliningtl.range + compiladorexplícito T.Pipelined(num_stages=)
Tensor Coretl.dotT.gemm com política de warp selecionável
BackendsNVIDIA (principalmente) / AMDNVIDIA / AMD / CPU / WebGPU / CuTeDSL, além de forks Ascend e MUSA

Em resumo: se você quer controle preciso sobre bloqueio (blocking), profundidade de pipeline e particionamento de warp sem escrever CUTLASS, o TileLang é o ponto ideal. Para operações simples elemento a elemento ou fusões leves, o Triton ainda é mais rápido de recorrer.

Configuração

plaintext
1conda create -n tilelang python=3.10 -y
2conda activate tilelang
3pip install tilelang                 # wheel pré-compilado, caminho mais fácil

Se você pretende mexer nas passagens do compilador, compile a partir do código-fonte (você precisará de um toolchain LLVM/CUDA local):

plaintext
1git clone --recursive https://github.com/tile-ai/tilelang.git
2cd tilelang && pip install -r requirements-dev.txt
3pip install -e . -v --no-build-isolation

Vamos escrever um GEMM

Começaremos com o kernel que todos começam: C = ReLU(A @ B). É pequeno, mas toca em cada primitiva que importa — buffers explícitos, cópia paralela, pipeline de software, a chamada de Tensor Core e um swizzle de L2.

plaintext
1import tilelang
2import tilelang.language as T
3import torch
4
5@tilelang.jit
6def matmul(M, N, K, block_M, block_N, block_K,
7           dtype="float16", accum_dtype="float"):
8
9    @T.prim_func
10    def matmul_relu_kernel(
11        A: T.Tensor((M, K), dtype),
12        B: T.Tensor((K, N), dtype),
13        C: T.Tensor((M, N), dtype),
14    ):
15        # dims da grade: (#blocos ao longo de N, #blocos ao longo de M); 128 threads por bloco
16        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M),
17                      threads=128) as (bx, by):
18
19            # Diga explicitamente onde cada tile vive.
20            A_shared = T.alloc_shared((block_M, block_K), dtype)         # memória compartilhada
21            B_shared = T.alloc_shared((block_K, block_N), dtype)
22            C_local  = T.alloc_fragment((block_M, block_N), accum_dtype) # acumulador em registradores
23
24            T.use_swizzle(panel_size=4, order="col")   # opcional: melhor reuso em L2
25            T.clear(C_local)                           # limpa o acumulador
26
27            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
28                T.copy(A[by * block_M, ko * block_K], A_shared)   # global -> shared
29                T.copy(B[ko * block_K, bx * block_N], B_shared)
30                T.gemm(A_shared, B_shared, C_local)               # MMA ao nível de tile
31
32            for i, j in T.Parallel(block_M, block_N):             # ReLU fundido
33                C_local[i, j] = T.max(C_local[i, j], 0)
34
35            T.copy(C_local, C[by * block_M, bx * block_N])        # escreve de volta
36
37    return matmul_relu_kernel
38
39
40M = N = K = 1024
41kernel = matmul(M, N, K, block_M=128, block_N=128, block_K=64)
42a = torch.randn(M, K, device="cuda", dtype=torch.float16)
43b = torch.randn(K, N, device="cuda", dtype=torch.float16)
44c = torch.empty(M, N, device="cuda", dtype=torch.float16)
45kernel(a, b, c)
46
47torch.testing.assert_close(c, torch.relu(a @ b), rtol=1e-2, atol=1e-2)
48print("gemm ok")

Aqui está o que cada peça faz:

  • Três buffers, três níveis. A_shared e B_shared vivem na memória compartilhada; C_local vive em registradores. Acumulador em registradores, operandos organizados via memória compartilhada — essa é a receita padrão de GEMM, exceto que aqui você a escreve. Essa é toda a diferença do Triton em uma linha.
  • T.copy é um atalho para uma cópia paralela. Ela se expande para um movimento do estilo T.Parallel, e o compilador deriva uma transferência global→shared vetorizada e coalescente a partir dela. Quando a cópia fica dentro de T.Pipelined, ela se torna automaticamente cp.async.
  • T.Pipelined(extent, num_stages=N) é um pipeline de software. num_stages=3 significa buffer triplo — enquanto você computa o tile-K ko, os carregamentos para ko+1 e ko+2 já estão em andamento. No Triton, isso é uma flag de compilação; aqui é apenas o loop, o que é mais fácil de raciocinar.
  • T.gemm(A, B, C) é o matmul ao nível de tile. Ele reduz para CuTe/MMA na NVIDIA e o intrínseco correspondente na AMD. Ele também aceita transpose_A / transpose_B e um policy=T.GemmWarpPolicy.* que controla como os warps dividem o tile de saída. Guarde esse argumento de política — ele é tudo quando chegarmos ao MLA.
  • T.use_swizzle reordena como os blocos de threads são escalonados para que blocos adjacentes no L2 sejam executados próximos no tempo. Geralmente resulta em alguns por cento a mais de largura de banda livre.

A figura abaixo mapeia tudo isso no hardware. Vale a pena ler comparando com o código, porque os pontos rotulados são exatamente onde o TileLang lhe dá o controle que o Triton mantém para si mesmo.

GEMM no TileLang — você mesmo coloca cada buffer na hierarquia. A_shared / B_shared ficam na memória compartilhada, C_local acumula em registradores entre os warps W0–W3, e o pipeline do loop K (num_stages=3) sobrepõe pré-buscas cp.async com a computação gemm atual.

Algumas primitivas que você usará

Você pode escrever a maioria dos kernels com um vocabulário pequeno.

  • Alocação: T.alloc_shared, T.alloc_fragment (registradores), T.alloc_local.
  • Mover e inicializar: T.copy(src, dst) entre quaisquer dois níveis; T.clear, T.fill.
  • Computação: T.gemm(...); T.Parallel(d0, d1, ...) para loops elemento a elemento (este é o ponto de entrada para inferência de layout); T.reduce_max / T.reduce_sum; matemática escalar como T.exp, T.exp2, T.max, T.infinity.
  • Escalonamento: T.Pipelined(extent, num_stages=), T.use_swizzle(...), T.annotate_layout(...) quando você precisar de um layout específico (evitar conflito de banco, swizzle personalizado).
  • Forma dinâmica: M = T.dynamic("m") para que você não precise recompilar por forma (é chamado de T.symbolic em algumas versões).

Verificando seu trabalho

Duas coisas que você desejará frequentemente. Para ver o que o compilador realmente emitiu:

plaintext
1print(kernel.get_kernel_source())     # CUDA / HIP gerado

E para medir o tempo:

plaintext
1profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
2print(f"latência: {profiler.do_bench()} ms")

T.print(buf) imprime um tile de dentro do kernel, e o examples/plot_layout do repositório desenha o layout de memória, o que é útil quando você está perseguindo um conflito de banco ou verificando um swizzle.

Agora um real: MLA decode

O GEMM mostra a mecânica. Este próximo mostra por que elas importam. Vamos percorrer o kernel de decode MLA (Multi-Head Latent Attention) do DeepSeek, porque é o exemplo mais claro do TileLang fazendo valer a pena. A referência do TileLang chega a aproximadamente o desempenho do FlashMLA no H100 (benchmarked no batch 64/128 em fp16, confortavelmente à frente do Triton e FlashInfer) em cerca de 80 linhas de Python. A questão interessante é como, porque a parte difícil do MLA não é a matemática — é a pressão sobre os registradores.

Vamos revisar o loop que todos conhecem. Todo kernel da família FlashAttention tem a mesma forma. Por bloco de query, você faz streaming sobre blocos de key/value e mantém um máximo corrente e um denominador, para que a matriz de scores completa nunca vá para a memória:

plaintext
1# acc_s : [block_M, block_N]  scores para este bloco KV
2# acc_o : [block_M, dim]      acumulador de saída
3for i in range(num_kv_blocks):
4    acc_s = Q @ K[i].T
5    m_prev       = scores_max
6    scores_max   = max(m_prev, rowmax(acc_s))
7    scores_scale = exp(m_prev - scores_max)
8    acc_o *= scores_scale                       # redimensiona a saída anterior
9    acc_s  = exp(acc_s - scores_max)            # probabilidades
10    acc_o += acc_s @ V[i]

Tanto acc_s quanto acc_o querem ficar nos registradores. Para MHA ou GQA, isso é ótimo. Para MLA, não é.

Aqui é onde fica difícil. As dimensões de head do MLA são grandes: query e key têm 512 de largura (uma parte "nope" de 512 sem codificação posicional, mais uma parte "rope" de 64), e value é 512. Então acc_o = [block_M, 512], e precisa permanecer residente nos registradores durante todo o loop KV.

Agora, traga o hardware. No Hopper, o caminho rápido é o wgmma.mma_async, que vincula 4 warps (128 threads) em um warpgroup e requer um M mínimo de 64. Portanto, o menor M que um warpgroup pode ter é 64, o que significa que um warpgroup estaria mantendo um acumulador de 64 × 512. Isso é grande demais para o arquivo de registradores de um único warpgroup. Ele sofre spill e o desempenho cai drasticamente.

MLA decode no TileLang — dividindo acc_o entre dois warpgroups. WG0 e WG1 computam Q·K^T (policy=FullCol), trocam suas metades de score através de S_shared, e então cada um computa sua fatia de coluna de P·V em acc_o_L / acc_o_R. Todo o bookkeeping (forma de acc_s, forma de S_shared, divisão de Q·K) é derivado da inferência de layout a partir da política FullCol que você anotou.

A solução é dividir a saída entre dois warpgroups. Você não pode reduzir M abaixo de 64, então o único eixo restante é dim. Use dois warpgroups: WG0 possui acc_o[:, :256], WG1 possui acc_o[:, 256:]. Agora cada um mantém um acumulador de 64 × 256, que cabe. Isso cria um segundo problema, no entanto: o passo P @ V (com policy=FullCol, cada warpgroup produzindo uma fatia de coluna da saída) precisa do acc_s completo, mas em Q @ K cada warpgroup naturalmente computou apenas metade dele. A resolução é uma troca de memória compartilhada. Durante Q @ K, cada warpgroup escreve sua metade de acc_s na memória compartilhada e lê a metade do outro warpgroup, então, posteriormente, ambos possuem o acc_s completo e podem cada um computar sua fatia de acc_o. O diagrama acima é exatamente isso: dividir os scores, trocar através de S_shared, dividir a saída.

No CuTe, você escreveria manualmente os layouts, os swizzles, o alinhamento de Tensor Core e a sincronização produtor/consumidor para conseguir isso. A razão pela qual isso colapsa para ~80 linhas aqui é a inferência de layout.

Vamos analisar o que a inferência de layout faz. Você anota a intenção nas chamadas T.gemm, e ela propaga as restrições através do programa para você:

  1. policy=FullCol em P @ V significa que cada warpgroup precisa do acc_s completo, então acc_s = [block_M, block_N].
  2. Isso se propaga de volta para o buffer de staging, então S_shared em T.copy(S_shared, acc_s) também é [block_M, block_N].
  3. E para a frente em Q @ K: com FullCol, a fatia de score de cada warpgroup é [block_M, block_N/2].

O insight chave é que você nunca escreve nenhuma dessas formas. Você escolhe a política de warp e escreve a matemática; as formas, os layouts swizzled e o código produtor/consumidor especializado para warp surgem todos da inferência.

O esqueleto do kernel. No MLA decode, a query se divide em uma parte "nope" (Q, dim 512) e uma parte "rope" (Q_pe, dim 64), e o latente comprimido serve tanto como K quanto V. Então o score é uma soma de dois GEMMs, e a saída é mais um. O loop interno parece com isso (um esqueleto representativo, não exato linha por linha — veja example_mla_decode.py):

plaintext
1# acc_s = Q_nope @ KV^T + Q_rope @ K_pe^T
2T.gemm(Q_shared,    KV_shared,   acc_s, transpose_B=True,
3       policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
4T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True,
5       policy=T.GemmWarpPolicy.FullCol)
6
7# softmax online
8T.copy(scores_max, scores_max_prev)
9T.fill(scores_max, -T.infinity(accum_dtype))
10T.reduce_max(acc_s, scores_max, dim=1, clear=False)
11# ... exp, redimensiona acc_o por scores_scale, reduce_sum em logsum ...
12
13# acc_o += P @ V  (V é o mesmo KV latente)
14T.copy(acc_s, acc_s_cast)
15T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)

A troca de S_shared entre os dois warpgroups é a parte que a inferência insere para você, uma vez que as políticas FullCol forçam o acc_s a ser completo por warpgroup.

A parte boa: as otimizações são de uma linha cada. É aqui que o TileLang compensa — todo o kit de ferramentas de desempenho é composto de comandos de uma linha, e o lowering bagunçado é tratado para você.

  • Swizzling de threadblock para reuso em L2: T.use_swizzle(panel_size, order="row").
  • Swizzling de memória compartilhada para conflitos de banco: T.annotate_layout({S_shared: T.layout.make_swizzled_layout(S_shared)}) — remapeamento de endereços estilo XOR para que acessos concorrentes se espalhem pelos bancos em vez de serializar.
  • Especialização de Warp: você escreve um script simples, e ele é reduzido para um warpgroup produtor (carregamentos TMA) mais warpgroups consumidores, com toda a sincronização mbarrier gerada. Nada disso aparece no seu código.
  • Pipelining: T.Pipelined(range, num_stages) sobrepõe carregamentos com computação — mais estágios, mais sobreposição, mas mais memória compartilhada, então é uma configuração.
  • Split-KV (estilo FlashDecoding): quando o batch é pequeno e os SMs estão ociosos, divida o contexto KV entre os SMs e combine. É um parâmetro num_split mais um kernel de combinação.

Portanto, o raciocínio genuinamente difícil — orçamento de registradores contra o piso de M≥64, quem possui o quê entre os warpgroups, a troca de memória compartilhada — você expressa escolhendo uma política e escrevendo a matemática. Tudo o que seriam centenas de linhas frágeis no CuTe é inferência e codegen. Essa é a proposta, e o MLA é onde ela é mais convincente.

Um caso nosso: um RMSNorm drop-in na AtlasCloud

O último exemplo é um de nossos kernels de produção na AtlasCloud, do VAE de geração de vídeo Wan no H100/H200. É uma ótima ilustração da outra coisa em que o TileLang é excelente: cobrir uma configuração que um kernel ajustado manualmente não consegue alcançar, com um drop-in limpo.

A configuração. Já enviamos um kernel RMSNorm + SiLU fundido ajustado manualmente. É rápido e compilado para as dimensões ocultas D ∈ {96, 192, 384} que uma configuração de modelo usa. Uma configuração mais nova precisa de larguras de canal como {160, 256, 320, 512, 640, 1024}, então nessa configuração o caminho rápido ajustado manualmente não pode ser executado. Escrevemos um drop-in em TileLang para cobrir exatamente essa lacuna.

O kernel TileLang. Um drop-in com a mesma interface (BTHWC in/out, mesma matemática, mesmo eps) que suporta qualquer C que seja múltiplo de 32. Duas passagens, totalmente coalescido, acumulador FP32:

plaintext
1@T.prim_func
2def main(X:     T.Tensor((M, C), dtype),      # M = B*T*H*W linhas
3         gamma: T.Tensor((C,),  dtype),
4         Y:     T.Tensor((M, C), dtype)):
5    with T.Kernel(T.ceildiv(M, BLOCK_M), threads=128) as bm:
6        X_chunk = T.alloc_shared((BLOCK_M, BLOCK_C), dtype)
7        ss      = T.alloc_fragment((BLOCK_M,), accum_dtype)   # soma de quadrados em FP32
8        # passagem 1: loop sobre C em pedaços de BLOCK_C, acumula soma de quadrados em FP32
9        # rinv = rsqrt(ss / C + 1e-5)
10        # passagem 2: recarrega X, y = silu(x * gamma * rinv), escreve de volta

BLOCK_C é 128/64/32 dependendo de C, para respeitar o limite TMA boxDim ≤ 256, e o acumulador FP32 mantém a soma dos quadrados longe de estourar em FP16. O dispatch mantém o caminho ajustado manualmente onde ele funciona e só recorre ao fallback quando precisa:

plaintext
1_ATLAS_SUPPORTED_D = (96, 192, 384)
2
3def rms_silu_dispatch(x, gamma, out):
4    if x.shape[-1] in _ATLAS_SUPPORTED_D:
5        atlas_rms_norm_silu(x, gamma, out=out)        # mantém o caminho ajustado manualmente
6    else:
7        tilelang_rms_silu_bthwc(x, gamma, out=out)    # cobre a lacuna

O que ganhamos. Tudo é ganho, e é um verdadeiro drop-in — mesma interface, mesma matemática, mesmo eps, então ele se encaixa atrás do dispatch existente sem alterações no call-site.

O quêGanho
Configuração anteriormente não suportada0 → 1 — agora funciona (a vitória principal)
RMSNorm em bloco de atenção vs a norma eager do PyTorch que substituiu42 μs → 20 μs (~2× mais rápido)
VAE ponta a ponta na resolução de produção (720×1280, 21 frames)~1.79× encode, ~1.78× decode

A primeira linha é o ponto real: o TileLang nos permitiu atender a uma configuração de modelo que anteriormente não tinha nenhum caminho rápido, sem tocar no kernel ajustado manualmente que já funciona para a outra configuração. Um drop-in, escrito em Python, e todo um caminho de modelo passou de "falha" para "em produção".

Onde o TileLang brilha

  • Controle preciso sobre bloqueio, estágios de pipeline e particionamento de warp, sem escrever CUTLASS/CuTe.
  • Kernels estruturalmente complexos e sensíveis ao layout: variantes de GEMM, família FlashAttention, MLA, atenção linear, GEMM quantizado fundido com dequant, roteamento MoE.
  • Cobrir uma operação ou configuração que seus kernels ajustados manualmente não alcançam (uma dimensão oculta fora do conjunto instanciado, um layout incomum) — e vencendo um fallback eager enquanto você está nisso.
  • Um corpo de kernel em todos os backends (NVIDIA / AMD / forks de fornecedores).
  • Todo o kit de ferramentas de otimização é uma chamada de cada vez — T.use_swizzle, T.annotate_layout, T.Pipelined, especialização de warp, split-KV — com o lowering tratado para você.

Conclusão

A parte legal do TileLang é que o raciocínio difícil permanece na sua cabeça, não no boilerplate. Você decide como dividir o trabalho entre os warps, onde os buffers vivem e quão profundo o pipeline funciona — e então a inferência de layout e a especialização de warp transformam isso nos layouts de registradores, nos swizzles e na dança produtor/consumidor que, de outra forma, seriam centenas de linhas de CuTe. Você escolhe uma política e escreve a matemática. Essa é toda a proposta, e é por isso que um kernel MLA de 80 linhas pode ficar ao lado de um CUTLASS ajustado manualmente.

Modelos recentes

Mais de 300 Modelos, Comece Agora,

Explorar Todos os Modelos

Join our Discord community

Join the Discord community for the latest model updates, prompts, and support.