[None][feat] AutoDeploy: Add the Triton kernel for MLA#12664
Conversation
Signed-off-by: Balamurugan Marimuthu <[email protected]>
…tral-Small-4-119B

Introduces _triton_mla_decode and _triton_mla_prefill replacing the torch reference implementation:

Decode (full launcher vs torch_backend):
  B=1  kv=2048: 113 µs vs  230 µs (2.0x)
  B=8  kv=512:  120 µs vs 1709 µs (14x)
  B=32 kv=512:   92 µs vs 6753 µs (73x)

Prefill (full launcher vs torch_backend):
  T=128:  146 µs vs  277 µs (1.9x)
  T=512:  163 µs vs  328 µs (2.0x)
  T=2048: 690 µs vs 2140 µs (3.1x)

Key optimizations:
- HEAD_BLOCK tiling: share KV cache loads across heads (iter 2)
- Split-K parallelism for small-batch decode (iters 26-51)
- CUDA graph compat: static dispatch via mla_cache.shape[1] (iter 57)
- TP correctness: cap head_block = min(head_block, num_heads) (iter 59)
- Prefill fast path: eliminate repeat_interleave/index_select overhead (iters 60-61)

Signed-off-by: Chenghao Zhang <[email protected]>
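The split-K item in the commit message above can be illustrated with a small pure-Python sketch. It is not the PR's Triton kernel: it assumes the common approach in which each KV chunk produces a partial output plus its log-sum-exp, and the partials are then merged with softmax-consistent weights. The function names `partial_attention` and `merge_partials` are hypothetical, and values are scalars (one head, head dimension 1) to keep the arithmetic visible.

```python
import math


def partial_attention(scores, values):
    """Softmax-weighted average over one KV chunk.

    Returns (output, log_sum_exp) so that chunks can be merged later.
    """
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    out = sum(w * v for w, v in zip(weights, values)) / denom
    return out, m + math.log(denom)


def merge_partials(partials):
    """Combine per-chunk (output, lse) pairs into the full-softmax result."""
    lse_max = max(lse for _, lse in partials)
    scales = [math.exp(lse - lse_max) for _, lse in partials]
    total = sum(scales)
    return sum(s * out for s, (out, _) in zip(scales, partials)) / total


# Illustrative data only: six KV positions split into three chunks of two.
scores = [0.5, -1.0, 2.0, 0.1, 1.5, -0.3]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

full, _ = partial_attention(scores, values)
split = merge_partials(
    [partial_attention(scores[i:i + 2], values[i:i + 2]) for i in range(0, 6, 2)]
)
```

Because each chunk is independent until the merge, the chunks can run in parallel, which is presumably why splitting helps most when the batch is too small to fill the GPU on its own.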
Signed-off-by: Chenghao Zhang <[email protected]>
…x_seq_len=4096, chunked prefill, cuda graph batch sizes) Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
📝 Walkthrough
This PR introduces a new Triton-based MLA (Multi-Head Latent Attention) backend alongside modifications to the existing torch-based implementation. Changes include: a new Triton custom op with decode and prefill kernels, removal of the optional output tensor parameter from torch MLA, module exports expansion, configuration updates, and comprehensive test coverage for the new Triton backend.
Changes
Sequence Diagram(s)
sequenceDiagram
participant Client as Query/Key/Value<br/>Input Tensors
participant Dispatch as MLA Dispatcher<br/>(triton_cached_mla_with_cache)
participant Decode as Triton Decode<br/>Path (s==1)
participant Prefill as Triton Prefill<br/>Path (s>1)
participant Cache as MLA Cache<br/>(Unpaged)
participant Kernel as Triton MLA<br/>Kernel
participant Proj as Output/Value<br/>Projection
Client->>Dispatch: q_nope, q_pe, compressed_kv, kpe
Dispatch->>Dispatch: Derive v_head_dim & scale
alt Decode Path
Dispatch->>Decode: s==1
Decode->>Cache: Read cached KV at position
Decode->>Kernel: Weight absorption & attention compute
Kernel->>Decode: Partial attention output
Decode->>Cache: Update cache via tensor indexing
Decode->>Proj: Apply output/value projection
Proj->>Client: [B, 1, N, v_head_dim]
else Prefill Path
Dispatch->>Prefill: s>1
alt Single Sequence
Prefill->>Kernel: Fast contiguous path
else Multi-Sequence
Prefill->>Prefill: Build per-token metadata<br/>(slots, positions, lengths)
Prefill->>Kernel: Batched attention compute
end
Kernel->>Prefill: Attention output
Prefill->>Proj: Apply output/value projection
Proj->>Client: [B, S, N, v_head_dim]
end
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py (1)
1-18: ⚠️ Potential issue | 🟠 Major
Add the NVIDIA SPDX header to this modified Python file.
This file is changed in the PR, but it still starts directly with a docstring instead of the required NVIDIA copyright/license header.
As per coding guidelines, "All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py` around lines 1 - 18, Add the required NVIDIA SPDX copyright header to the top of the modified Python module (torch_backend_mla.py) before the module docstring; include the appropriate SPDX identifier and the year of latest modification and the NVIDIA copyright line consistent with project policy. Ensure the header appears above the existing docstring so tools and legal scans pick it up, and keep the rest of the file (including symbols torch_cached_mla_with_cache and TorchBackendMLAAttention) unchanged.
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py (1)
1-17: ⚠️ Potential issue | 🟠 Major
Add the NVIDIA SPDX header before exporting the new backend.
This modified module still lacks the required NVIDIA copyright/license header at the top.
As per coding guidelines, "All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py` around lines 1 - 17, Add the required NVIDIA SPDX copyright/license header at the top of this module (above the module docstring) including the appropriate year of the latest meaningful modification and the SPDX identifier; ensure it follows the project's standard header format and appears before any imports or the docstring so files like this one that export TorchBackendMLAAttention, FlashInferMLAAttention, TritonMLAAttention, torch_mla, torch_backend_mla_with_cache, flashinfer_mla_with_cache, and triton_cached_mla_with_cache include the mandated NVIDIA header.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py`:
- Line 325: The custom op registration for
"auto_deploy::torch_cached_mla_with_cache" currently hides side effects by using
mutates_args=(); update the decorator to declare the cache as mutated (e.g.
mutates_args=("mla_cache",)) so PyTorch's export/functionalization sees the
write; make the change on the `@torch.library.custom_op` line for the op
registered in torch_backend_mla.py (the function used by _update_mla_cache and
_torch_mla_generate_with_absorption modifies mla_cache), and ensure the op
signature/argument name matches the declared "mla_cache" identifier.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/triton_mla.py`:
- Line 568: The code currently uses floor-division of num_heads by head_block
(e.g., computing num_head_blocks = num_heads // head_block) which drops tail
heads and leaves weighted_kv/out uninitialized; change these launches to use
ceil-division (num_head_blocks = (num_heads + head_block - 1) // head_block) and
add an in-kernel guard that checks computed head index < num_heads before
writing (so kernels skip work for masked tail lanes), or alternatively enforce
head_block divisibility by num_heads; update every occurrence that launches on
head_block (e.g., where head_block, num_heads, weighted_kv, out are used) to use
the ceil division and add the head-index bound check in the kernel entry to
ensure no heads are lost.
- Line 869: The custom op decorator for
auto_deploy::triton_cached_mla_with_cache incorrectly marks the op as pure;
update the `@torch.library.custom_op` on the function implementing
triton_cached_mla_with_cache to declare that mla_cache is mutated (e.g., change
mutates_args=() to include "mla_cache"), because helper functions
_triton_mla_decode and _triton_mla_prefill perform indexed assignments into
mla_cache and the op must be treated as mutating that argument.
- Around line 804-806: The current code forces a device-to-host sync by calling
seq_lengths.sum().item() (total_tokens) inside _triton_mla_prefill; instead,
change _triton_mla_prefill signature to accept an int num_prefill_tokens and
remove the seq_lengths.sum().item() call (and any use of total_tokens computed
from .item()), compute num_prefill_tokens at the wrapper call site using
BatchInfo.get_absorbed_info() and pass that int into _triton_mla_prefill; update
all call sites (the wrapper that currently prepares seq_lengths/seq_start_l
before calling _triton_mla_prefill) to pass the new num_prefill_tokens parameter
so no device-to-host sync occurs inside seq_lengths/seq_start_l handling.
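The head_block finding in the list above (floor division dropping tail heads) can be checked with plain Python. This is a sketch of the indexing arithmetic only, not of the Triton launch. The helper name `covered_heads` is hypothetical; it lists which head indices a grid of `num_head_blocks` programs would actually process.

```python
def covered_heads(num_heads, head_block, use_ceil_div):
    """Return the head indices processed by a grid over head blocks."""
    if use_ceil_div:
        # Ceil division: the last block may be only partially filled.
        num_head_blocks = (num_heads + head_block - 1) // head_block
    else:
        # Floor division: a partial tail block is never launched.
        num_head_blocks = num_heads // head_block

    covered = []
    for block_id in range(num_head_blocks):
        for lane in range(head_block):
            head = block_id * head_block + lane
            if head < num_heads:  # bound check that masks out-of-range tail lanes
                covered.append(head)
    return covered


floor_result = covered_heads(num_heads=6, head_block=4, use_ceil_div=False)
ceil_result = covered_heads(num_heads=6, head_block=4, use_ceil_div=True)
```

With 6 heads and a block of 4, floor division launches one block and skips heads 4 and 5, while ceil division launches two blocks and relies on the bound check to skip the two out-of-range lanes.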
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 88c4c24e-c001-44f9-94e4-75ad798d8358
📒 Files selected for processing (5)
examples/auto_deploy/model_registry/configs/mistral_small_4_119b.yaml
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/triton_mla.py
tests/unittest/auto_deploy/singlegpu/custom_ops/mla/test_triton_mla_op.py
Signed-off-by: Chenghao Zhang <[email protected]>
/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_B200-AutoDeploy-1"
We don't necessarily need an accuracy test; can we add a smoke test? https://sourcegraph.com/r/github.com/NVIDIA/TensorRT-LLM/-/blob/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_build_small_single.py
suyoggupta left a comment
Approving to unblock. Please add a smoke test
PR_Github #41258 [ run ] completed with state
It seems even for the smoke test, it needs to be added to the LLM_ROOT. Will create a task to track that.
/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_B200-AutoDeploy-1"
PR_Github #41484 [ run ] triggered by Bot. Commit:
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"
PR_Github #41493 [ run ] triggered by Bot. Commit:
PR_Github #41493 [ run ] completed with state
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"
PR_Github #41721 [ run ] triggered by Bot. Commit:
PR_Github #41721 [ run ] completed with state
Signed-off-by: Chenghao Zhang <[email protected]>
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"
PR_Github #41990 [ run ] triggered by Bot. Commit:
PR_Github #41990 [ run ] completed with state
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"
PR_Github #42163 [ run ] triggered by Bot. Commit:
PR_Github #42163 [ run ] completed with state
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"
PR_Github #42172 [ run ] triggered by Bot. Commit:
PR_Github #42172 [ run ] completed with state
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"
PR_Github #42182 [ run ] triggered by Bot. Commit:
PR_Github #42182 [ run ] completed with state
Summary by CodeRabbit
Release Notes
New Features
Refactoring