Escrevendo Kernels de Alto Desempenho em TileLang, de GEMM a MLA

O TileLang permite que você escreva kernels de GPU em Python com controle explícito sobre tiles, pipelines e warps. De uma simples GEMM até a decodificação MLA do DeepSeek.

Escrevendo Kernels de Alto Desempenho em TileLang, de GEMM a MLA

Se você escreve kernels de GPU, você se encontra em algum ponto de um espectro. Em uma ponta está o Triton: rápido de escrever, mas o compilador toma a maioria das decisões de layout e memória compartilhada para você. Na outra ponta está o CUTLASS / CuTe: controle total, ao custo de muito maquinário de templates. O TileLang fica no meio. Você escreve em Python, mas diz explicitamente o que vive na memória compartilhada, como o pipeline é organizado e como os warps dividem o trabalho — e uma passagem de inferência de layout preenche o restante.

Neste post, abordaremos o modelo mental, escreveremos um GEMM e, em seguida, construiremos um kernel de produção real: o decode de MLA do DeepSeek, onde as decisões interessantes realmente aparecem. O objetivo não é ser exaustivo. É mostrar como você pensa sobre blocos (tiles) e onde o TileLang silenciosamente faz as partes difíceis para você. Terminaremos com uma história mais típica de produção — um kernel onde a vitória não foi a velocidade.

O modelo mental

Aqui está a ideia completa em três pontos.

  • Um tile é um objeto de primeira classe. Um pedaço de dados com formato definido (block_M × block_K, por exemplo) é de propriedade e operado por um bloco de threads, um warp ou uma thread. Você para de pensar puramente no nível de bloco de threads como faz no Triton, e para de gerenciar manualmente threads individuais como faz no CUDA.
  • Você mesmo coloca os buffers na hierarquia de memória. Você declara o que vai para a memória compartilhada (T.alloc_shared), o que vai para registradores (T.alloc_fragment) e o que é local à thread. Esta é a maior diferença em relação ao Triton, que esconde a alocação de memória compartilhada e o staging dentro do compilador.
  • O compilador infere o mapeamento de threads. Uma vez que você definiu onde um tile vive e qual operação é executada nele (uma cópia, um gemm, uma redução), uma passagem de inferência de layout o paraleliza entre as threads e resolve os layouts de registradores e de memória compartilhada. Você pode substituir isso quando precisar, mas na maioria das vezes não precisará. Esta passagem é o recurso de sustentação — quando chegarmos ao MLA, você verá o porquê.

Se você vem do Triton, aqui está o mapeamento aproximado.

 TritonTileLang
Granularidadebloco de threads + vetorização implícitatile (bloco / warp / thread)
Memória comp.gerenciada pelo compiladorexplicit alloc_shared + copy
Layouto compilador decideinferido, mas você pode anotar
Pipeliningtl.range + compiladorexplícito T.Pipelined(num_stages=)
Tensor Coretl.dotT.gemm com política de warp selecionável
BackendsNVIDIA (principalmente) / AMDNVIDIA / AMD / CPU / WebGPU / CuTeDSL, além de forks Ascend e MUSA

Em resumo: se você quer controle preciso sobre bloqueio (blocking), profundidade de pipeline e particionamento de warp sem escrever CUTLASS, o TileLang é o ponto ideal. Para operações simples elemento a elemento ou fusões leves, o Triton ainda é mais rápido de recorrer.

Configuração

plaintext
1conda create -n tilelang python=3.10 -y
2conda activate tilelang
3pip install tilelang                 # wheel pré-compilado, caminho mais fácil

Se você pretende mexer nas passagens do compilador, compile a partir do código-fonte (você precisará de um toolchain LLVM/CUDA local):

plaintext
1git clone --recursive https://github.com/tile-ai/tilelang.git
2cd tilelang && pip install -r requirements-dev.txt
3pip install -e . -v --no-build-isolation

Vamos escrever um GEMM

Começaremos com o kernel que todos começam: C = ReLU(A @ B). É pequeno, mas toca em cada primitiva que importa — buffers explícitos, cópia paralela, pipeline de software, a chamada de Tensor Core e um swizzle de L2.

plaintext
1import tilelang
2import tilelang.language as T
3import torch
4
5@tilelang.jit
6def matmul(M, N, K, block_M, block_N, block_K,
7           dtype="float16", accum_dtype="float"):
8
9    @T.prim_func
10    def matmul_relu_kernel(
11        A: T.Tensor((M, K), dtype),
12        B: T.Tensor((K, N), dtype),
13        C: T.Tensor((M, N), dtype),
14    ):
15        # dims da grade: (#blocos ao longo de N, #blocos ao longo de M); 128 threads por bloco
16        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M),
17                      threads=128) as (bx, by):
18
19            # Diga explicitamente onde cada tile vive.
20            A_shared = T.alloc_shared((block_M, block_K), dtype)         # memória compartilhada
21            B_shared = T.alloc_shared((block_K, block_N), dtype)
22            C_local  = T.alloc_fragment((block_M, block_N), accum_dtype) # acumulador em registradores
23
24            T.use_swizzle(panel_size=4, order="col")   # opcional: melhor reuso em L2
25            T.clear(C_local)                           # limpa o acumulador
26
27            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
28                T.copy(A[by * block_M, ko * block_K], A_shared)   # global -> shared
29                T.copy(B[ko * block_K, bx * block_N], B_shared)
30                T.gemm(A_shared, B_shared, C_local)               # MMA ao nível de tile
31
32            for i, j in T.Parallel(block_M, block_N):             # ReLU fundido
33                C_local[i, j] = T.max(C_local[i, j], 0)
34
35            T.copy(C_local, C[by * block_M, bx * block_N])        # escreve de volta
36
37    return matmul_relu_kernel
38
39
40M = N = K = 1024
41kernel = matmul(M, N, K, block_M=128, block_N=128, block_K=64)
42a = torch.randn(M, K, device="cuda", dtype=torch.float16)
43b = torch.randn(K, N, device="cuda", dtype=torch.float16)
44c = torch.empty(M, N, device="cuda", dtype=torch.float16)
45kernel(a, b, c)
46
47torch.testing.assert_close(c, torch.relu(a @ b), rtol=1e-2, atol=1e-2)
48print("gemm ok")

Aqui está o que cada peça faz:

  • Três buffers, três níveis. A_shared e B_shared vivem na memória compartilhada; C_local vive em registradores. Acumulador em registradores, operandos organizados via memória compartilhada — essa é a receita padrão de GEMM, exceto que aqui você a escreve. Essa é toda a diferença do Triton em uma linha.
  • T.copy é um atalho para uma cópia paralela. Ela se expande para um movimento do estilo T.Parallel, e o compilador deriva uma transferência global→shared vetorizada e coalescente a partir dela. Quando a cópia fica dentro de T.Pipelined, ela se torna automaticamente cp.async.
  • T.Pipelined(extent, num_stages=N) é um pipeline de software. num_stages=3 significa buffer triplo — enquanto você computa o tile-K ko, os carregamentos para ko+1 e ko+2 já estão em andamento. No Triton, isso é uma flag de compilação; aqui é apenas o loop, o que é mais fácil de raciocinar.
  • T.gemm(A, B, C) é o matmul ao nível de tile. Ele reduz para CuTe/MMA na NVIDIA e o intrínseco correspondente na AMD. Ele também aceita transpose_A / transpose_B e um policy=T.GemmWarpPolicy.* que controla como os warps dividem o tile de saída. Guarde esse argumento de política — ele é tudo quando chegarmos ao MLA.
  • T.use_swizzle reordena como os blocos de threads são escalonados para que blocos adjacentes no L2 sejam executados próximos no tempo. Geralmente resulta em alguns por cento a mais de largura de banda livre.

A figura abaixo mapeia tudo isso no hardware. Vale a pena ler comparando com o código, porque os pontos rotulados são exatamente onde o TileLang lhe dá o controle que o Triton mantém para si mesmo.

GEMM no TileLang — você mesmo coloca cada buffer na hierarquia. A_shared / B_shared ficam na memória compartilhada, C_local acumula em registradores entre os warps W0–W3, e o pipeline do loop K (num_stages=3) sobrepõe pré-buscas cp.async com a computação gemm atual.

Algumas primitivas que você usará

Você pode escrever a maioria dos kernels com um vocabulário pequeno.

  • Alocação: T.alloc_shared, T.alloc_fragment (registradores), T.alloc_local.
  • Mover e inicializar: T.copy(src, dst) entre quaisquer dois níveis; T.clear, T.fill.
  • Computação: T.gemm(...); T.Parallel(d0, d1, ...) para loops elemento a elemento (este é o ponto de entrada para inferência de layout); T.reduce_max / T.reduce_sum; matemática escalar como T.exp, T.exp2, T.max, T.infinity.
  • Escalonamento: T.Pipelined(extent, num_stages=), T.use_swizzle(...), T.annotate_layout(...) quando você precisar de um layout específico (evitar conflito de banco, swizzle personalizado).
  • Forma dinâmica: M = T.dynamic("m") para que você não precise recompilar por forma (é chamado de T.symbolic em algumas versões).

Verificando seu trabalho

Duas coisas que você desejará frequentemente. Para ver o que o compilador realmente emitiu:

plaintext
1print(kernel.get_kernel_source())     # CUDA / HIP gerado

E para medir o tempo:

plaintext
1profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
2print(f"latência: {profiler.do_bench()} ms")

T.print(buf) imprime um tile de dentro do kernel, e o examples/plot_layout do repositório desenha o layout de memória, o que é útil quando você está perseguindo um conflito de banco ou verificando um swizzle.

Agora um real: MLA decode

O GEMM mostra a mecânica. Este próximo mostra por que elas importam. Vamos percorrer o kernel de decode MLA (Multi-Head Latent Attention) do DeepSeek, porque é o exemplo mais claro do TileLang fazendo valer a pena. A referência do TileLang chega a aproximadamente o desempenho do FlashMLA no H100 (benchmarked no batch 64/128 em fp16, confortavelmente à frente do Triton e FlashInfer) em cerca de 80 linhas de Python. A questão interessante é como, porque a parte difícil do MLA não é a matemática — é a pressão sobre os registradores.

Vamos revisar o loop que todos conhecem. Todo kernel da família FlashAttention tem a mesma forma. Por bloco de query, você faz streaming sobre blocos de key/value e mantém um máximo corrente e um denominador, para que a matriz de scores completa nunca vá para a memória:

plaintext
1# acc_s : [block_M, block_N]  scores para este bloco KV
2# acc_o : [block_M, dim]      acumulador de saída
3for i in range(num_kv_blocks):
4    acc_s = Q @ K[i].T
5    m_prev       = scores_max
6    scores_max   = max(m_prev, rowmax(acc_s))
7    scores_scale = exp(m_prev - scores_max)
8    acc_o *= scores_scale                       # redimensiona a saída anterior
9    acc_s  = exp(acc_s - scores_max)            # probabilidades
10    acc_o += acc_s @ V[i]

Tanto acc_s quanto acc_o querem ficar nos registradores. Para MHA ou GQA, isso é ótimo. Para MLA, não é.

Aqui é onde fica difícil. As dimensões de head do MLA são grandes: query e key têm 512 de largura (uma parte "nope" de 512 sem codificação posicional, mais uma parte "rope" de 64), e value é 512. Então acc_o = [block_M, 512], e precisa permanecer residente nos registradores durante todo o loop KV.

Agora, traga o hardware. No Hopper, o caminho rápido é o wgmma.mma_async, que vincula 4 warps (128 threads) em um warpgroup e requer um M mínimo de 64. Portanto, o menor M que um warpgroup pode ter é 64, o que significa que um warpgroup estaria mantendo um acumulador de 64 × 512. Isso é grande demais para o arquivo de registradores de um único warpgroup. Ele sofre spill e o desempenho cai drasticamente.

MLA decode no TileLang — dividindo acc_o entre dois warpgroups. WG0 e WG1 computam Q·K^T (policy=FullCol), trocam suas metades de score através de S_shared, e então cada um computa sua fatia de coluna de P·V em acc_o_L / acc_o_R. Todo o bookkeeping (forma de acc_s, forma de S_shared, divisão de Q·K) é derivado da inferência de layout a partir da política FullCol que você anotou.

A solução é dividir a saída entre dois warpgroups. Você não pode reduzir M abaixo de 64, então o único eixo restante é dim. Use dois warpgroups: WG0 possui acc_o[:, :256], WG1 possui acc_o[:, 256:]. Agora cada um mantém um acumulador de 64 × 256, que cabe. Isso cria um segundo problema, no entanto: o passo P @ V (com policy=FullCol, cada warpgroup produzindo uma fatia de coluna da saída) precisa do acc_s completo, mas em Q @ K cada warpgroup naturalmente computou apenas metade dele. A resolução é uma troca de memória compartilhada. Durante Q @ K, cada warpgroup escreve sua metade de acc_s na memória compartilhada e lê a metade do outro warpgroup, então, posteriormente, ambos possuem o acc_s completo e podem cada um computar sua fatia de acc_o. O diagrama acima é exatamente isso: dividir os scores, trocar através de S_shared, dividir a saída.

No CuTe, você escreveria manualmente os layouts, os swizzles, o alinhamento de Tensor Core e a sincronização produtor/consumidor para conseguir isso. A razão pela qual isso colapsa para ~80 linhas aqui é a inferência de layout.

Vamos analisar o que a inferência de layout faz. Você anota a intenção nas chamadas T.gemm, e ela propaga as restrições através do programa para você:

  1. policy=FullCol em P @ V significa que cada warpgroup precisa do acc_s completo, então acc_s = [block_M, block_N].
  2. Isso se propaga de volta para o buffer de staging, então S_shared em T.copy(S_shared, acc_s) também é [block_M, block_N].
  3. E para a frente em Q @ K: com FullCol, a fatia de score de cada warpgroup é [block_M, block_N/2].

O insight chave é que você nunca escreve nenhuma dessas formas. Você escolhe a política de warp e escreve a matemática; as formas, os layouts swizzled e o código produtor/consumidor especializado para warp surgem todos da inferência.

O esqueleto do kernel. No MLA decode, a query se divide em uma parte "nope" (Q, dim 512) e uma parte "rope" (Q_pe, dim 64), e o latente comprimido serve tanto como K quanto V. Então o score é uma soma de dois GEMMs, e a saída é mais um. O loop interno parece com isso (um esqueleto representativo, não exato linha por linha — veja example_mla_decode.py):

plaintext
1# acc_s = Q_nope @ KV^T + Q_rope @ K_pe^T
2T.gemm(Q_shared,    KV_shared,   acc_s, transpose_B=True,
3       policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
4T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True,
5       policy=T.GemmWarpPolicy.FullCol)
6
7# softmax online
8T.copy(scores_max, scores_max_prev)
9T.fill(scores_max, -T.infinity(accum_dtype))
10T.reduce_max(acc_s, scores_max, dim=1, clear=False)
11# ... exp, redimensiona acc_o por scores_scale, reduce_sum em logsum ...
12
13# acc_o += P @ V  (V é o mesmo KV latente)
14T.copy(acc_s, acc_s_cast)
15T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)

A troca de S_shared entre os dois warpgroups é a parte que a inferência insere para você, uma vez que as políticas FullCol forçam o acc_s a ser completo por warpgroup.

A parte boa: as otimizações são de uma linha cada. É aqui que o TileLang compensa — todo o kit de ferramentas de desempenho é composto de comandos de uma linha, e o lowering bagunçado é tratado para você.

  • Swizzling de threadblock para reuso em L2: T.use_swizzle(panel_size, order="row").
  • Swizzling de memória compartilhada para conflitos de banco: T.annotate_layout({S_shared: T.layout.make_swizzled_layout(S_shared)}) — remapeamento de endereços estilo XOR para que acessos concorrentes se espalhem pelos bancos em vez de serializar.
  • Especialização de Warp: você escreve um script simples, e ele é reduzido para um warpgroup produtor (carregamentos TMA) mais warpgroups consumidores, com toda a sincronização mbarrier gerada. Nada disso aparece no seu código.
  • Pipelining: T.Pipelined(range, num_stages) sobrepõe carregamentos com computação — mais estágios, mais sobreposição, mas mais memória compartilhada, então é uma configuração.
  • Split-KV (estilo FlashDecoding): quando o batch é pequeno e os SMs estão ociosos, divida o contexto KV entre os SMs e combine. É um parâmetro num_split mais um kernel de combinação.

Portanto, o raciocínio genuinamente difícil — orçamento de registradores contra o piso de M≥64, quem possui o quê entre os warpgroups, a troca de memória compartilhada — você expressa escolhendo uma política e escrevendo a matemática. Tudo o que seriam centenas de linhas frágeis no CuTe é inferência e codegen. Essa é a proposta, e o MLA é onde ela é mais convincente.

Um caso nosso: um RMSNorm drop-in na AtlasCloud

O último exemplo é um de nossos kernels de produção na AtlasCloud, do VAE de geração de vídeo Wan no H100/H200. É uma ótima ilustração da outra coisa em que o TileLang é excelente: cobrir uma configuração que um kernel ajustado manualmente não consegue alcançar, com um drop-in limpo.

A configuração. Já enviamos um kernel RMSNorm + SiLU fundido ajustado manualmente. É rápido e compilado para as dimensões ocultas D ∈ {96, 192, 384} que uma configuração de modelo usa. Uma configuração mais nova precisa de larguras de canal como {160, 256, 320, 512, 640, 1024}, então nessa configuração o caminho rápido ajustado manualmente não pode ser executado. Escrevemos um drop-in em TileLang para cobrir exatamente essa lacuna.

O kernel TileLang. Um drop-in com a mesma interface (BTHWC in/out, mesma matemática, mesmo eps) que suporta qualquer C que seja múltiplo de 32. Duas passagens, totalmente coalescido, acumulador FP32:

plaintext
1@T.prim_func
2def main(X:     T.Tensor((M, C), dtype),      # M = B*T*H*W linhas
3         gamma: T.Tensor((C,),  dtype),
4         Y:     T.Tensor((M, C), dtype)):
5    with T.Kernel(T.ceildiv(M, BLOCK_M), threads=128) as bm:
6        X_chunk = T.alloc_shared((BLOCK_M, BLOCK_C), dtype)
7        ss      = T.alloc_fragment((BLOCK_M,), accum_dtype)   # soma de quadrados em FP32
8        # passagem 1: loop sobre C em pedaços de BLOCK_C, acumula soma de quadrados em FP32
9        # rinv = rsqrt(ss / C + 1e-5)
10        # passagem 2: recarrega X, y = silu(x * gamma * rinv), escreve de volta

BLOCK_C é 128/64/32 dependendo de C, para respeitar o limite TMA boxDim ≤ 256, e o acumulador FP32 mantém a soma dos quadrados longe de estourar em FP16. O dispatch mantém o caminho ajustado manualmente onde ele funciona e só recorre ao fallback quando precisa:

plaintext
1_ATLAS_SUPPORTED_D = (96, 192, 384)
2
3def rms_silu_dispatch(x, gamma, out):
4    if x.shape[-1] in _ATLAS_SUPPORTED_D:
5        atlas_rms_norm_silu(x, gamma, out=out)        # mantém o caminho ajustado manualmente
6    else:
7        tilelang_rms_silu_bthwc(x, gamma, out=out)    # cobre a lacuna

O que ganhamos. Tudo é ganho, e é um verdadeiro drop-in — mesma interface, mesma matemática, mesmo eps, então ele se encaixa atrás do dispatch existente sem alterações no call-site.

O quêGanho
Configuração anteriormente não suportada0 → 1 — agora funciona (a vitória principal)
RMSNorm em bloco de atenção vs a norma eager do PyTorch que substituiu42 μs → 20 μs (~2× mais rápido)
VAE ponta a ponta na resolução de produção (720×1280, 21 frames)~1.79× encode, ~1.78× decode

A primeira linha é o ponto real: o TileLang nos permitiu atender a uma configuração de modelo que anteriormente não tinha nenhum caminho rápido, sem tocar no kernel ajustado manualmente que já funciona para a outra configuração. Um drop-in, escrito em Python, e todo um caminho de modelo passou de "falha" para "em produção".

Onde o TileLang brilha

  • Controle preciso sobre bloqueio, estágios de pipeline e particionamento de warp, sem escrever CUTLASS/CuTe.
  • Kernels estruturalmente complexos e sensíveis ao layout: variantes de GEMM, família FlashAttention, MLA, atenção linear, GEMM quantizado fundido com dequant, roteamento MoE.
  • Cobrir uma operação ou configuração que seus kernels ajustados manualmente não alcançam (uma dimensão oculta fora do conjunto instanciado, um layout incomum) — e vencendo um fallback eager enquanto você está nisso.
  • Um corpo de kernel em todos os backends (NVIDIA / AMD / forks de fornecedores).
  • Todo o kit de ferramentas de otimização é uma chamada de cada vez — T.use_swizzle, T.annotate_layout, T.Pipelined, especialização de warp, split-KV — com o lowering tratado para você.

Conclusão

A parte legal do TileLang é que o raciocínio difícil permanece na sua cabeça, não no boilerplate. Você decide como dividir o trabalho entre os warps, onde os buffers vivem e quão profundo o pipeline funciona — e então a inferência de layout e a especialização de warp transformam isso nos layouts de registradores, nos swizzles e na dança produtor/consumidor que, de outra forma, seriam centenas de linhas de CuTe. Você escolhe uma política e escreve a matemática. Essa é toda a proposta, e é por isso que um kernel MLA de 80 linhas pode ficar ao lado de um CUTLASS ajustado manualmente.

Modelos recentes

Uma API para toda a IA de mídia.

Explorar Todos os Modelos

Join our Discord community

Join the Discord community for the latest model updates, prompts, and support.

Escrevendo Kernels de Alto Desempenho em TileLang, de GEMM a MLA