[TRTLLM-9521][feat] Unfuse indexer.wk from attention GEMM for DS-V3.2 NVFP4#11989
Walkthrough
The changes refactor DeepSeek V3 attention handling by removing the extra head dimension from the kv projection output, eliminating the post-load weight fusion method, and moving indexer computation from the weight loading phase to explicit separate calls during the forward pass.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/modules/attention.py`:
- Around line 1486-1488: The unconditional call to indexer.wk (producing
indexer_k) is executed even on the short-MHA skip path; change the code so
indexer_k is only computed when indexer routing is actually used (i.e., when not
skipping short-MHA or when num_generations > 0). Concretely, wrap or move the
indexer_k = self.indexer.wk(hidden_states) invocation behind the same condition
that decides routing (referencing use_short_mha_for_ctx, num_generations and the
short-MHA skip branch) so the projection is avoided on the short-MHA bypass
path; ensure downstream uses of indexer_k remain valid by computing it only in
the branch that requires it.
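The gating suggested in this comment can be sketched as follows. This is a minimal illustration in plain Python, not the code in `attention.py`: `CountingProjection`, `needs_indexer`, and `forward` are hypothetical names, and the condition is taken literally from the comment (compute `indexer_k` when the short-MHA path is not taken, or when `num_generations > 0`).

```python
from dataclasses import dataclass


@dataclass
class CountingProjection:
    """Hypothetical stand-in for indexer.wk that records how often it runs."""

    calls: int = 0

    def __call__(self, hidden_states):
        self.calls += 1
        # Placeholder math; a real projection would be a GEMM.
        return [2.0 * x for x in hidden_states]


def needs_indexer(use_short_mha_for_ctx: bool, num_generations: int) -> bool:
    """Indexer routing is used unless the short-MHA bypass covers the whole batch."""
    return (not use_short_mha_for_ctx) or num_generations > 0


def forward(wk, hidden_states, use_short_mha_for_ctx, num_generations):
    """Compute indexer_k only on the branch that consumes it."""
    indexer_k = None
    if needs_indexer(use_short_mha_for_ctx, num_generations):
        indexer_k = wk(hidden_states)
    return indexer_k


wk = CountingProjection()
forward(wk, [1.0, 3.0], use_short_mha_for_ctx=True, num_generations=0)
print(wk.calls)  # the projection was skipped on the bypass path
```

Returning `None` on the bypass path makes any accidental downstream use of `indexer_k` fail loudly rather than silently read a stale tensor.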
📒 Files selected for processing (2)
tensorrt_llm/_torch/models/modeling_deepseekv3.py
tensorrt_llm/_torch/modules/attention.py
…V3.2 NVFP4 Signed-off-by: peihengh <[email protected]>
88a637a to 9c95ca0
pengbowang-nv
left a comment
Attention part LGTM
/bot run
PR_Github #38292 [ run ] triggered by Bot. Commit:
PR_Github #38292 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #38305 [ run ] triggered by Bot. Commit:
PR_Github #38305 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #38328 [ run ] triggered by Bot. Commit:
PR_Github #38328 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #38340 [ run ] triggered by Bot. Commit:
PR_Github #38340 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #38458 [ run ] triggered by Bot. Commit:
PR_Github #38458 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #38486 [ run ] triggered by Bot. Commit:
PR_Github #38486 [ run ] completed with state
Signed-off-by: peihengh <[email protected]>
/bot run
PR_Github #38898 [ run ] triggered by Bot. Commit:
PR_Github #38898 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #38919 [ run ] triggered by Bot. Commit:
PR_Github #38919 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #39104 [ run ] triggered by Bot. Commit:
PR_Github #39104 [ run ] completed with state
/bot run --stage-list "DGX_B200-4_GPUs-PyTorch-Post-Merge-1, DGX_B200-4_GPUs-PyTorch-Post-Merge-2"
PR_Github #39133 [ run ] triggered by Bot. Commit:
PR_Github #39133 [ run ] completed with state
/bot skip --comment "already has a green CI 6 hours ago"
PR_Github #39160 [ skip ] triggered by Bot. Commit:
PR_Github #39160 [ skip ] completed with state
… NVFP4 (NVIDIA#11989) Signed-off-by: peihengh <[email protected]>
Description
Model weight loading is shared between DeepSeek R1/V3/V3.2. On the existing main branch, `indexer.wk` was fused into `kv_a_proj_with_mqa` at load time via `post_load_weights`, forcing it to share the fused module's NVFP4 quantization. This can cause accuracy issues when serving DS V3.2 NVFP4 checkpoints whose attention is also quantized to NVFP4, because `indexer.wk` is more sensitive to quantization.

This change unfuses `indexer.wk` so that it loads through the standard weight loading path and retains whatever precision the checkpoint provides. It is also a prerequisite for a later task, which will convert the indexer ops to TF32 and fuse `indexer.wk` with `indexer.w_proj` into a single TF32 kernel, for complete support of NVFP4 attention quantization in DS V3.2.

Changes
- Remove `indexer.head_dim` from the `kv_a_proj_with_mqa` output dimension (2240 → 2112 for V3.2)
- Compute `indexer_k` via a separate `indexer.wk(hidden_states)` call instead of extracting it from the fused 4-way split
- Remove `post_load_weights` (no longer needed since weights aren't fused)
- Guard the `q_a_proj` dtype check with `is_lite`, since V3-Lite has no `q_a_proj`

Based on Binghan's closed draft PR #9776.
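The dimension change (2240 → 2112) can be checked with a small sketch. Only the two totals and the "4-way split" come from this PR. The individual widths below (`q_lora_rank` 1536, `kv_lora_rank` 512, `qk_rope_head_dim` 64, indexer head dim 128) and their ordering are assumptions based on the commonly published DeepSeek-V3-style configuration, and `split_sizes` and `split_row` are hypothetical helpers, not TensorRT-LLM functions.

```python
# Assumed widths (not stated in this PR); they are chosen so that the
# totals match the 2240 and 2112 figures quoted in the description.
Q_LORA_RANK = 1536
KV_LORA_RANK = 512
QK_ROPE_HEAD_DIM = 64
INDEXER_HEAD_DIM = 128


def split_sizes(fuse_indexer_wk: bool) -> list[int]:
    """Widths of the chunks carved out of the fused projection output."""
    sizes = [Q_LORA_RANK, KV_LORA_RANK, QK_ROPE_HEAD_DIM]
    if fuse_indexer_wk:
        # Old behaviour: indexer.wk rides along as a fourth chunk.
        sizes.append(INDEXER_HEAD_DIM)
    return sizes


def split_row(row, sizes):
    """Split one output row into consecutive chunks of the given widths."""
    if len(row) != sum(sizes):
        raise ValueError(f"expected {sum(sizes)} values, got {len(row)}")
    chunks, start = [], 0
    for size in sizes:
        chunks.append(row[start:start + size])
        start += size
    return chunks


print(sum(split_sizes(True)), sum(split_sizes(False)))  # 2240 2112

# After unfusing, the fused output splits three ways; indexer_k would come
# from a separate projection of the same hidden states.
q, kv, k_pe = split_row(list(range(2112)), split_sizes(False))
print(len(q), len(kv), len(k_pe))  # 1536 512 64
```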
Test Coverage
- `test_sparse_mla_forward.py`: 28 passed on B200, 17 passed on H100
- `test_short_seq_mha.py`: all passed on both architectures
- `test_nvfp4_attn_multi_gpus` with new accuracy section: MMLU 87.9%, GSM8K 95.2% (`DeepSeek-V3.2-NVFP4-FP4attn`)

PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.