Skip to content

[None][feat] AutoDeploy: Add the Triton kernel for MLA#12664

Merged
nvchenghaoz merged 24 commits into
NVIDIA:mainfrom
nv-auto-deploy:chenghao/triton_mla_clean
Apr 8, 2026
Merged

[None][feat] AutoDeploy: Add the Triton kernel for MLA#12664
nvchenghaoz merged 24 commits into
NVIDIA:mainfrom
nv-auto-deploy:chenghao/triton_mla_clean

Conversation

@nvchenghaoz
Copy link
Copy Markdown
Collaborator

@nvchenghaoz nvchenghaoz commented Apr 1, 2026

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Triton-based Multi-head Latent Attention (MLA) backend supporting both single-token and multi-token attention computation
    • Updated AutoDeploy configuration to utilize the new Triton MLA implementation with CUDA graph-based compilation
  • Refactoring

    • Streamlined attention cache handling logic and improved memory management for attention operations

bmarimuthu-nv and others added 18 commits March 23, 2026 10:04
Signed-off-by: Balamurugan Marimuthu <[email protected]>
Signed-off-by: Balamurugan Marimuthu <[email protected]>
Signed-off-by: Balamurugan Marimuthu <[email protected]>
Signed-off-by: Balamurugan Marimuthu <[email protected]>
…tral-Small-4-119B

Introduces _triton_mla_decode and _triton_mla_prefill replacing the torch
reference implementation:

Decode (full launcher vs torch_backend):
  B=1  kv=2048: 113 µs vs 230 µs  (2.0x)
  B=8  kv=512:  120 µs vs 1709 µs (14x)
  B=32 kv=512:   92 µs vs 6753 µs (73x)

Prefill (full launcher vs torch_backend):
  T=128:  146 µs vs 277 µs (1.9x)
  T=512:  163 µs vs 328 µs (2.0x)
  T=2048: 690 µs vs 2140 µs (3.1x)

Key optimizations:
- HEAD_BLOCK tiling: share KV cache loads across heads (iter 2)
- Split-K parallelism for small-batch decode (iters 26-51)
- CUDA graph compat: static dispatch via mla_cache.shape[1] (iter 57)
- TP correctness: cap head_block = min(head_block, num_heads) (iter 59)
- Prefill fast path: eliminate repeat_interleave/index_select overhead (iters 60-61)

Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
…x_seq_len=4096, chunked prefill, cuda graph batch sizes)

Signed-off-by: Chenghao Zhang <[email protected]>
@nvchenghaoz nvchenghaoz requested a review from a team as a code owner April 1, 2026 17:52
@nvchenghaoz nvchenghaoz requested a review from QiJune April 1, 2026 17:52
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 1, 2026

📝 Walkthrough

Walkthrough

This PR introduces a new Triton-based MLA (Multi-Head Latent Attention) backend alongside modifications to the existing torch-based implementation. Changes include: a new Triton custom op with decode and prefill kernels, removal of the optional output tensor parameter from torch MLA, module exports expansion, configuration updates, and comprehensive test coverage for the new Triton backend.

Changes

Cohort / File(s) Summary
Configuration Update
examples/auto_deploy/model_registry/configs/mistral_small_4_119b.yaml
Switched compilation from torch-simple to torch-cudagraph and updated insert_cached_mla_attention backend from torch_mla to triton_mla.
Module Exports
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py
Extended public API to export TritonMLAAttention and triton_cached_mla_with_cache from the new triton_mla module.
Torch Backend Modifications
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py
Removed mutates_args=("mla_cache",) from custom op registration, eliminated optional out parameter and associated output-handling logic, simplified token bookkeeping and reshape operations for both decode and prefill paths.
Triton MLA Implementation
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/triton_mla.py
Added new Triton backend with multihead and split-K MLA kernels, triton_cached_mla_with_cache custom op registration, decode and prefill path implementations, and TritonMLAAttention descriptor for runtime dispatch.
Triton MLA Tests
tests/unittest/auto_deploy/singlegpu/custom_ops/mla/test_triton_mla_op.py
Added comprehensive test suite comparing Triton MLA outputs against torch reference across decode and prefill scenarios, including cache consistency validation and descriptor registration verification.

Sequence Diagram(s)

sequenceDiagram
    participant Client as Query/Key/Value<br/>Input Tensors
    participant Dispatch as MLA Dispatcher<br/>(triton_cached_mla_with_cache)
    participant Decode as Triton Decode<br/>Path (s==1)
    participant Prefill as Triton Prefill<br/>Path (s>1)
    participant Cache as MLA Cache<br/>(Unpaged)
    participant Kernel as Triton MLA<br/>Kernel
    participant Proj as Output/Value<br/>Projection

    Client->>Dispatch: q_nope, q_pe, compressed_kv, kpe
    Dispatch->>Dispatch: Derive v_head_dim & scale
    alt Decode Path
        Dispatch->>Decode: s==1
        Decode->>Cache: Read cached KV at position
        Decode->>Kernel: Weight absorption & attention compute
        Kernel->>Decode: Partial attention output
        Decode->>Cache: Update cache via tensor indexing
        Decode->>Proj: Apply output/value projection
        Proj->>Client: [B, 1, N, v_head_dim]
    else Prefill Path
        Dispatch->>Prefill: s>1
        alt Single Sequence
            Prefill->>Kernel: Fast contiguous path
        else Multi-Sequence
            Prefill->>Prefill: Build per-token metadata<br/>(slots, positions, lengths)
            Prefill->>Kernel: Batched attention compute
        end
        Kernel->>Prefill: Attention output
        Prefill->>Proj: Apply output/value projection
        Proj->>Client: [B, S, N, v_head_dim]
    end
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 62.16% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Description check ⚠️ Warning The PR description is missing entirely (contains only '@coderabbitai summary' token), lacking required sections such as Description, Test Coverage, and PR Checklist completion. Add a comprehensive description explaining what Triton MLA kernel does, why it was added, which tests validate the changes, and confirm PR checklist items.
✅ Passed checks (1 passed)
Check name Status Explanation
Title check ✅ Passed The pull request title accurately describes the main feature being added: a Triton kernel implementation for MLA (multi-head linear attention) in AutoDeploy, which aligns with the substantial changes across multiple files and the new triton_mla.py implementation.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Comment thread tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py Outdated
Comment thread tests/unittest/auto_deploy/singlegpu/custom_ops/mla/test_triton_mla_op.py Outdated
Comment thread tests/unittest/auto_deploy/singlegpu/custom_ops/mla/test_triton_mla_op.py Outdated
Comment thread tests/unittest/auto_deploy/singlegpu/custom_ops/mla/test_triton_mla_op.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py (1)

1-18: ⚠️ Potential issue | 🟠 Major

Add the NVIDIA SPDX header to this modified Python file.

This file is changed in the PR, but it still starts directly with a docstring instead of the required NVIDIA copyright/license header.

As per coding guidelines, All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py` around
lines 1 - 18, Add the required NVIDIA SPDX copyright header to the top of the
modified Python module (torch_backend_mla.py) before the module docstring;
include the appropriate SPDX identifier and the year of latest modification and
the NVIDIA copyright line consistent with project policy. Ensure the header
appears above the existing docstring so tools and legal scans pick it up, and
keep the rest of the file (including symbols torch_cached_mla_with_cache and
TorchBackendMLAAttention) unchanged.
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py (1)

1-17: ⚠️ Potential issue | 🟠 Major

Add the NVIDIA SPDX header before exporting the new backend.

This modified module still lacks the required NVIDIA copyright/license header at the top.

As per coding guidelines, All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py` around lines 1 -
17, Add the required NVIDIA SPDX copyright/license header at the top of this
module (above the module docstring) including the appropriate year of the latest
meaningful modification and the SPDX identifier; ensure it follows the project's
standard header format and appears before any imports or the docstring so files
like this one that export TorchBackendMLAAttention, FlashInferMLAAttention,
TritonMLAAttention, torch_mla, torch_backend_mla_with_cache,
flashinfer_mla_with_cache, and triton_cached_mla_with_cache include the mandated
NVIDIA header.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py`:
- Line 325: The custom op registration for
"auto_deploy::torch_cached_mla_with_cache" currently hides side effects by using
mutates_args=(); update the decorator to declare the cache as mutated (e.g.
mutates_args=("mla_cache",)) so PyTorch's export/functionalization sees the
write; make the change on the `@torch.library.custom_op` line for the op
registered in torch_backend_mla.py (the function used by _update_mla_cache and
_torch_mla_generate_with_absorption modifies mla_cache), and ensure the op
signature/argument name matches the declared "mla_cache" identifier.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/triton_mla.py`:
- Line 568: The code currently uses floor-division of num_heads by head_block
(e.g., computing num_head_blocks = num_heads // head_block) which drops tail
heads and leaves weighted_kv/out uninitialized; change these launches to use
ceil-division (num_head_blocks = (num_heads + head_block - 1) // head_block) and
add an in-kernel guard that checks computed head index < num_heads before
writing (so kernels skip work for masked tail lanes), or alternatively enforce
head_block divisibility by num_heads; update every occurrence that launches on
head_block (e.g., where head_block, num_heads, weighted_kv, out are used) to use
the ceil division and add the head-index bound check in the kernel entry to
ensure no heads are lost.
- Line 869: The custom op decorator for
auto_deploy::triton_cached_mla_with_cache incorrectly marks the op as pure;
update the `@torch.library.custom_op` on the function implementing
triton_cached_mla_with_cache to declare that mla_cache is mutated (e.g., change
mutates_args=() to include "mla_cache"), because helper functions
_triton_mla_decode and _triton_mla_prefill perform indexed assignments into
mla_cache and the op must be treated as mutating that argument.
- Around line 804-806: The current code forces a device-to-host sync by calling
seq_lengths.sum().item() (total_tokens) inside _triton_mla_prefill; instead,
change _triton_mla_prefill signature to accept an int num_prefill_tokens and
remove the seq_lengths.sum().item() call (and any use of total_tokens computed
from .item()), compute num_prefill_tokens at the wrapper call site using
BatchInfo.get_absorbed_info() and pass that int into _triton_mla_prefill; update
all call sites (the wrapper that currently prepares seq_lengths/seq_start_l
before calling _triton_mla_prefill) to pass the new num_prefill_tokens parameter
so no device-to-host sync occurs inside seq_lengths/seq_start_l handling.

---

Outside diff comments:
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py`:
- Around line 1-17: Add the required NVIDIA SPDX copyright/license header at the
top of this module (above the module docstring) including the appropriate year
of the latest meaningful modification and the SPDX identifier; ensure it follows
the project's standard header format and appears before any imports or the
docstring so files like this one that export TorchBackendMLAAttention,
FlashInferMLAAttention, TritonMLAAttention, torch_mla,
torch_backend_mla_with_cache, flashinfer_mla_with_cache, and
triton_cached_mla_with_cache include the mandated NVIDIA header.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py`:
- Around line 1-18: Add the required NVIDIA SPDX copyright header to the top of
the modified Python module (torch_backend_mla.py) before the module docstring;
include the appropriate SPDX identifier and the year of latest modification and
the NVIDIA copyright line consistent with project policy. Ensure the header
appears above the existing docstring so tools and legal scans pick it up, and
keep the rest of the file (including symbols torch_cached_mla_with_cache and
TorchBackendMLAAttention) unchanged.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 88c4c24e-c001-44f9-94e4-75ad798d8358

📥 Commits

Reviewing files that changed from the base of the PR and between fd49221 and 1ada357.

📒 Files selected for processing (5)
  • examples/auto_deploy/model_registry/configs/mistral_small_4_119b.yaml
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mla/triton_mla.py
  • tests/unittest/auto_deploy/singlegpu/custom_ops/mla/test_triton_mla_op.py

Comment thread tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py Outdated
Comment thread tensorrt_llm/_torch/auto_deploy/custom_ops/mla/triton_mla.py
Comment thread tensorrt_llm/_torch/auto_deploy/custom_ops/mla/triton_mla.py Outdated
Comment thread tensorrt_llm/_torch/auto_deploy/custom_ops/mla/triton_mla.py Outdated
Signed-off-by: Chenghao Zhang <[email protected]>
@nvchenghaoz nvchenghaoz changed the title [None][feat] AutoDeploy: Add the triton kernel for mla [None][feat] AutoDeploy Add the triton kernel for mla Apr 1, 2026
@nvchenghaoz nvchenghaoz changed the title [None][feat] AutoDeploy Add the triton kernel for mla [None][feat] AutoDeploy: Add the Triton kernel for MLA Apr 1, 2026
@nvchenghaoz
Copy link
Copy Markdown
Collaborator Author

/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_B200-AutoDeploy-1"

@suyoggupta
Copy link
Copy Markdown
Collaborator

@suyoggupta suyoggupta self-requested a review April 1, 2026 21:10
Copy link
Copy Markdown
Collaborator

@suyoggupta suyoggupta left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving to unblock. Please add a smoke test

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #41258 [ run ] completed with state SUCCESS. Commit: abfc784
/LLM/main/L0_MergeRequest_PR pipeline #32216 (Partly Tested) completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nvchenghaoz
Copy link
Copy Markdown
Collaborator Author

Approving to unblock. Please add a smoke test

It seems even for the smoke test, it needs to be added to the LLM_ROOT. Will create a task to track that.

@nvchenghaoz
Copy link
Copy Markdown
Collaborator Author

/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_B200-AutoDeploy-1"

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #41484 [ run ] triggered by Bot. Commit: 7d2cac2 Link to invocation

@nvchenghaoz
Copy link
Copy Markdown
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #41493 [ run ] triggered by Bot. Commit: 7d2cac2 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #41493 [ run ] completed with state SUCCESS. Commit: 7d2cac2
/LLM/main/L0_MergeRequest_PR pipeline #32413 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nvchenghaoz
Copy link
Copy Markdown
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@nvchenghaoz
Copy link
Copy Markdown
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #41721 [ run ] triggered by Bot. Commit: d9a6ad2 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #41721 [ run ] completed with state SUCCESS. Commit: d9a6ad2
/LLM/main/L0_MergeRequest_PR pipeline #32623 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Signed-off-by: Chenghao Zhang <[email protected]>
@nvchenghaoz
Copy link
Copy Markdown
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #41990 [ run ] triggered by Bot. Commit: f253181 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #41990 [ run ] completed with state SUCCESS. Commit: f253181
/LLM/main/L0_MergeRequest_PR pipeline #32841 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nvchenghaoz
Copy link
Copy Markdown
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42163 [ run ] triggered by Bot. Commit: f253181 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42163 [ run ] completed with state FAILURE. Commit: f253181
/LLM/main/L0_MergeRequest_PR pipeline #32991 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nvchenghaoz
Copy link
Copy Markdown
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42172 [ run ] triggered by Bot. Commit: f253181 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42172 [ run ] completed with state SUCCESS. Commit: f253181
/LLM/main/L0_MergeRequest_PR pipeline #33000 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nvchenghaoz
Copy link
Copy Markdown
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42182 [ run ] triggered by Bot. Commit: f253181 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42182 [ run ] completed with state SUCCESS. Commit: f253181
/LLM/main/L0_MergeRequest_PR pipeline #33007 completed with status: 'SUCCESS'

CI Report

Link to invocation

@nvchenghaoz nvchenghaoz merged commit 2fe39c1 into NVIDIA:main Apr 8, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants