[None][feat] Temporally-Correlated Heuristic-guided Indexer TopK for Sparse Attention#12385
Conversation
📝 Walkthrough
This PR introduces a heuristic top-K decoding mechanism that accelerates top-K selection by reusing the previous step's top-K indices as hints. It adds new CUDA kernel implementations, extends the IndexerTopK interface to accept pre-computed indices, integrates with PyTorch operators and DSA sparse attention, and introduces comprehensive distribution-based tests.
Changes
Sequence Diagram
sequenceDiagram
participant PyTorch as PyTorch Inference
participant DSA as DSA Attention Layer
participant Indexer as Indexer (sparse_attn_indexer)
participant TopK as indexer_topk_decode Op
participant Kernel as CUDA Kernels
PyTorch->>DSA: Generate tokens (decode step)
DSA->>Indexer: Call with use_custom_topk flag
alt enable_heuristic_topk enabled
Indexer->>Indexer: Derive pre_idx from heuristic_prev_topk<br/>(this layer's TopK stored at the previous decode step)
Indexer->>TopK: Call with pre_idx parameter
TopK->>Kernel: Dispatch to launchHeuristicTopKDecode
Kernel->>Kernel: Check if N <= topK (simple case)<br/>else use heuristic path
Kernel-->>TopK: Return top-K indices
TopK-->>Indexer: top-K results
Indexer->>Indexer: Update heuristic_prev_topk<br/>with current layer's topK
else enable_heuristic_topk disabled
Indexer->>TopK: Call with pre_idx=None
TopK->>Kernel: Dispatch to standard path<br/>(insertion/radix/multi-block)
Kernel-->>TopK: Return top-K indices
TopK-->>Indexer: top-K results
end
Indexer-->>DSA: Sparse attention indices
DSA-->>PyTorch: Attention output
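To make the hint idea in the diagram concrete, here is a minimal pure-Python sketch, not the CUDA kernel from this PR. It assumes the hint is used only to estimate a selection threshold, and it swaps in a plain sort for the histogram and interpolation steps the kernel uses. The function name `topk_with_hints` is hypothetical.

```python
def topk_with_hints(logits, k, prev_idx):
    """Return the sorted indices of the k largest logits, using prev_idx as a hint."""
    n = len(logits)
    if n <= k:
        # Simple case from the diagram: every position is selected.
        return list(range(n))
    # Estimate a threshold from the hinted positions, skipping -1 or stale entries.
    hinted = [logits[i] for i in prev_idx if 0 <= i < n]
    threshold = min(hinted) if hinted else float("-inf")
    # One pass keeps only the candidates at or above the threshold.
    candidates = [i for i, v in enumerate(logits) if v >= threshold]
    if len(candidates) < k:
        # The hint was too optimistic; fall back to scanning everything.
        candidates = list(range(n))
    # Exact selection on the (usually much smaller) candidate set.
    candidates.sort(key=lambda i: (-logits[i], i))
    return sorted(candidates[:k])


print(topk_with_hints([0.1, 0.9, 0.3, 0.8, 0.2, 0.7], 3, [1, 3, 4]))  # [1, 3, 5]
```

A good hint shrinks the candidate set and a bad one only costs the fallback scan, so the result stays exact either way.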
Actionable comments posted: 9
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/kernels/heuristic_topk.cuh`:
- Around line 751-776: The launcher launchHeuristicTopK is unsafe because
heuristicTopKJob/heuristicTopKKernel assumes the compile-time TOP_K (2048) but
launchHeuristicTopK accepts a runtime topK; either enforce and reject mismatched
sizes or propagate the runtime topK into the kernel/job. Fix by validating topK
against TOP_K at the top of launchHeuristicTopK and return cudaErrorInvalidValue
(or similar) if they differ, or modify the kernel/job interfaces
(heuristicTopKKernel and heuristicTopKJob) to accept and use the runtime topK
everywhere (remove compile-time-only assumptions/padding/thresholding) so no
writes occur past the provided outputValues/outputIndices.
- Around line 772-774: The explicit launch with a 64-bit index type
(launchHeuristicTopK<float, int64_t>) corrupts buffers because the kernel launch
reinterprets preIdx/outputIndices as int*/int const* and uses 32-bit element
sizes; fix by making the kernel and launch correctly use the index template
IdxT: either remove the int64_t instantiation and only use 32-bit indices, or
fully templatize the kernel/job on IdxT and change the reinterpret_casts and
buffer pointer types in the launch sites (e.g., heuristicTopKKernel launch and
the calls around the launchHeuristicTopK instantiations at the same block and at
lines ~779-782) so pointers, casts, and writes/reads use reinterpret_cast<IdxT
const*> / reinterpret_cast<IdxT*> and the correct element sizes.
In `@cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu`:
- Around line 85-98: The current process-wide static flag `configured` causes
`cudaFuncSetAttribute(heuristicTopKMultiRowKernel,
cudaFuncAttributeMaxDynamicSharedMemorySize, ...)` to be applied only once
across all GPUs; make this device-aware by replacing the single `configured`
with a per-device guard (e.g., a container keyed by the current device id
obtained via `cudaGetDevice`) or simply remove the guard and call
`cudaFuncSetAttribute` unconditionally when `smemSize > 48u * 1024u && smemSize
<= static_cast<size_t>(maxSmem)`; ensure you check `device` via `cudaGetDevice`
and set the per-device flag after a successful `cudaFuncSetAttribute` for
`heuristicTopKMultiRowKernel` to avoid suppressing opt-in on other GPUs.
In `@cpp/tensorrt_llm/thop/IndexerTopKOp.cpp`:
- Around line 74-85: The pre_idx handling only checks is_cuda() but must also
ensure it lives on the same CUDA device as logits to avoid passing a foreign
device pointer into invokeIndexerTopKDecode; update the pre_idx branch (the
block using preIdxTensor, preIdxPtr, preIdxStride, preIdxCount) to assert the
devices match (compare preIdxTensor.device() with logits.device() or
logits.device().index()) using TORCH_CHECK and a clear error message, so the
kernel always receives a pointer on the same CUDA device as logits.
- Around line 80-81: The TORCH_CHECK in IndexerTopKOp.cpp incorrectly allows
preIdxTensor.size(0) == numRows64 when next_n > 1, which leads
heuristicTopKMultiRowKernel to index hints by rowIdx / next_n and ignore per-row
hints; update the validation around preIdxTensor, next_n and numRows64 so that
when next_n > 1 the only accepted hint shape is preIdxTensor.size(0) * next_n ==
numRows64 (or equivalently preIdxTensor.size(0) == numRows64 / next_n), and only
allow the preIdxTensor.size(0) == numRows64 shorthand when next_n == 1; locate
and change the TORCH_CHECK that currently references preIdxTensor.size(0),
next_n and numRows64 (and any nearby comments) to enforce this stricter
condition so callers cannot silently pass per-row hints that will be ignored by
heuristicTopKMultiRowKernel.
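The stricter shape rule requested above can be sketched as follows. This is a Python illustration of the condition only, not the `TORCH_CHECK` in `IndexerTopKOp.cpp`; the helper name `check_pre_idx_rows` is hypothetical.

```python
def check_pre_idx_rows(pre_idx_rows, num_rows, next_n):
    """Validate the hint row count and return the expected number of hint rows."""
    if next_n < 1:
        raise ValueError("next_n must be >= 1")
    if num_rows % next_n != 0:
        raise ValueError("num_rows must be a multiple of next_n")
    # One hint row per request; for next_n == 1 this equals num_rows,
    # so the per-row shorthand is only accepted in that case.
    expected = num_rows // next_n
    if pre_idx_rows != expected:
        raise ValueError(
            f"pre_idx has {pre_idx_rows} rows, expected {expected} "
            f"(num_rows={num_rows}, next_n={next_n})"
        )
    return expected
```

With this rule a caller that passes per-row hints while `next_n > 1` gets an error instead of having most of the hints ignored.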
In `@examples/longbench/eval_longbench_v1.py`:
- Around line 389-409: Add upfront argument validation to reject combinations
that require PyTorch when the user selected the TensorRT backend: if
args.backend == "tensorrt" and (args.dsa_sparse is True or args.mtp > 0) then
raise a clear CLI error (e.g., parser.error or raise ValueError) before building
DeepSeekSparseAttentionConfig or MTPDecodingConfig; reference the symbols
args.dsa_sparse, DeepSeekSparseAttentionConfig, args.mtp, and MTPDecodingConfig
so the check runs prior to the blocks that construct those configs and prevents
LLM(...) from failing later.
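A minimal sketch of the upfront check, assuming the script uses `argparse`. The `--dsa_sparse` and `--mtp` flags appear in this PR; the `--backend` choices and defaults shown here are assumptions, not copied from `eval_longbench_v1.py`.

```python
import argparse


def parse_args(argv=None):
    """Parse CLI arguments and reject backend/feature combinations early."""
    parser = argparse.ArgumentParser(description="LongBench v1 evaluation (sketch)")
    # The choices and defaults below are assumed for illustration.
    parser.add_argument("--backend", choices=["pytorch", "tensorrt"], default="pytorch")
    parser.add_argument("--dsa_sparse", action="store_true")
    parser.add_argument("--mtp", type=int, default=0)
    args = parser.parse_args(argv)
    # Fail before any DeepSeekSparseAttentionConfig or MTPDecodingConfig is built.
    if args.backend == "tensorrt" and (args.dsa_sparse or args.mtp > 0):
        parser.error("--dsa_sparse and --mtp require --backend pytorch")
    return args
```

`parser.error` prints the usage line plus the message and exits with status 2, so the user sees the problem before any model is loaded.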
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Around line 533-540: The allocated buffer heuristic_prev_topk (created via
get_empty(...)) is being read/shifted before being seeded; reset newly allocated
or newly-assigned slots to -1 and ensure the shift logic (where staging += 1 is
applied) only increments entries that are >= 0 so padding/sentinel -1 stays as
“no hint”. Update the code paths that touch heuristic_prev_topk (and the same
pattern at the other occurrences) to explicitly initialize new request slots to
-1 after get_empty, and change the increment/shift logic to conditionally add 1
only for non-negative elements (e.g., mask with >=0 or use where/conditional
update) so stale -1 values never turn into candidate 0.
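The conditional shift described above, shown on plain Python lists for clarity. On tensors the same effect could be had with something like `torch.where(x >= 0, x + 1, x)`. The names `NO_HINT` and `shift_hints` are hypothetical.

```python
NO_HINT = -1  # sentinel for "no hint in this slot"


def shift_hints(prev_topk, offset=1):
    """Shift valid hint indices by offset; leave sentinel entries untouched."""
    return [i + offset if i >= 0 else NO_HINT for i in prev_topk]


print(shift_hints([0, 5, -1, 7]))  # [1, 6, -1, 8]
```

An unconditional `+= 1` would turn every `-1` into `0`, which the kernel would then treat as a real candidate index.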
In `@tests/unittest/_torch/thop/parallel/test_indexer_topk.py`:
- Around line 313-367: The issue is that a block was accidentally dedented to
module scope so _run_cute_dsl_topk_test() returns immediately after seed setup,
making the rest of the test body run at import time and creating a duplicate
test name; fix it by re-indenting the displaced block so all code that generates
logits, runs assertions, and parametrized test logic is inside the
_run_cute_dsl_topk_test function (identify by the function name
_run_cute_dsl_topk_test and the seed setup lines), and remove or rename the
stray duplicate test_indexer_topk_decode_dist definition at module scope so
there is no F811 duplicate symbol.
- Around line 834-835: Clamp seq_lens to ensure no sequence is shorter than
next_n before computing row_ends: after calling generate_seq_lens(...) assign
seq_lens = seq_lens.clamp(min=next_n) so that row_ends = seq_lens[row_indices] -
next_n + next_n_offset + 1 cannot produce non-positive values; this prevents
generate_pre_idx from receiving non-positive valid_len and blowing up. Ensure
you update the variable used by subsequent logic (row_ends, generate_pre_idx) so
the clamp takes effect.
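A small sketch of why the clamp matters, using plain lists instead of tensors. It assumes rows are laid out request-major (all `next_n` rows of a request are adjacent); the helper name `compute_row_ends` is hypothetical.

```python
def compute_row_ends(seq_lens, next_n):
    """Return the valid length of each row after clamping seq_lens to next_n."""
    clamped = [max(s, next_n) for s in seq_lens]
    ends = []
    for s in clamped:
        for next_n_offset in range(next_n):
            # Same formula as in the test: seq_len - next_n + next_n_offset + 1
            ends.append(s - next_n + next_n_offset + 1)
    return ends


print(compute_row_ends([1, 5], 3))  # [1, 2, 3, 3, 4, 5]
```

Without the clamp, the first request (length 1, `next_n = 3`) would produce row ends of -1, 0 and 1, and the first two would reach `generate_pre_idx` as non-positive lengths.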
📒 Files selected for processing (12)
cpp/tensorrt_llm/kernels/IndexerTopK.h
cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu
cpp/tensorrt_llm/kernels/heuristicTopKDecode.h
cpp/tensorrt_llm/kernels/heuristic_topk.cuh
cpp/tensorrt_llm/kernels/indexerTopK.cu
cpp/tensorrt_llm/thop/IndexerTopKOp.cpp
examples/longbench/eval_longbench_v1.py
tensorrt_llm/_torch/attention_backend/sparse/dsa.py
tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
tensorrt_llm/_torch/model_config.py
tensorrt_llm/llmapi/llm_args.py
tests/unittest/_torch/thop/parallel/test_indexer_topk.py
|
/bot run |
|
/bot help |
GitHub Bot Help
Provide a user-friendly way for developers to interact with a Jenkins server.
Run
See details below for each supported subcommand.
Details
Launch build/test pipelines. All previously running jobs will be killed.
kill
Kill all running builds associated with pull request.
skip
Skip testing for latest commit on pull request.
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
|
/bot run |
|
PR_Github #39861 [ run ] triggered by Bot. Commit: |
|
PR_Github #39861 [ run ] completed with state
|
|
/bot run |
|
PR_Github #39901 [ run ] triggered by Bot. Commit: |
|
PR_Github #39901 [ run ] completed with state
|
3ef3244 to
5f02778
Compare
|
/bot run |
…e indexer
Enable the heuristic TopK decode path by persisting each layer's previous-step TopK indices and passing them as pre_idx hints to indexer_topk_decode. Consecutive decode steps have nearly identical attention patterns, so prior indices bootstrap a better initial threshold and reduce interpolation iterations in the heuristic kernel.
Key changes:
- Add enable_heuristic_topk config field (default False) to DeepSeekSparseAttentionConfig for opt-in activation.
- Per-layer lazy-allocated buffers in Indexer for TopK index persistence with request-id-based slot management and stale-request cleanup.
- Save prefill last-token TopK as seed for first decode step; save decode last-MTP-position TopK for subsequent steps.
- Fix MTP preIdx row indexing in heuristic kernel (rowIdx/next_n).
- Relax pre_idx size check in thop binding for per-request shape.
Made-with: claude-4.6-opus-high
Signed-off-by: longcheng-nv <[email protected]>
Two fixes for the heuristic TopK decode path:
1. Add +1 offset to pre_idx in _gather_prev_topk_for_decode: the saved TopK came from a query at position P (last MTP pos of previous step), while the current step's first query is at P+1. Shifting by +1 preserves relative distances under RoPE for a more accurate initial threshold in the heuristic kernel.
2. Propagate enable_heuristic_topk through model_config.py: the DeepSeekV32 config builder was reconstructing DeepSeekSparseAttentionConfig without forwarding enable_heuristic_topk from the user's sparse_attention_config, causing the field to always default to False.
Made-with: claude-4.6-opus-high
Signed-off-by: longcheng-nv <[email protected]>
…uffers
Redesign the heuristic TopK pre_idx management to be fully CUDA Graph compatible. Replace the old Indexer-side dict/slot approach (which used .item() and torch.tensor() forbidden during graph capture) with pre-allocated metadata buffers using a feedback loop pattern.
Key changes:
- Per-layer 3D buffer heuristic_prev_topk on metadata, allocated via get_empty(capture_graph=True) for stable addresses across replays.
- Shared staging buffer for +1 RoPE offset, all in-place ops.
- Each graph replay's write becomes the next replay's read.
- Heuristic now works with CUDA Graphs enabled (no need to disable).
- Refine heuristic_topk.cuh kernel for improved convergence.
Made-with: claude-4.6-opus-high
Signed-off-by: longcheng-nv <[email protected]>
…aluation
Extend eval_longbench_v1.py to support DeepSeek DSA sparse attention alongside existing RocketKV, with heuristic TopK and MTP speculative decoding options.
Key changes:
- Add --dsa_sparse flag as mutually exclusive with --rocket_sparse
- Add --enable_heuristic_topk for DSA heuristic TopK pre_idx reuse
- Add --mtp for MTP speculative decoding layers
- Update usage examples for both RocketKV and DSA workflows
Made-with: claude-4.6-opus-high
Signed-off-by: longcheng-nv <[email protected]>
…tions
Three kernel-level optimizations (OPT5/OPT6/OPT7) to reduce latency of
the single-CTA heuristic TopK on Blackwell (sm_100):
OPT5 — Skip Phase 3 blockCountGE re-scan when Phase 2 already confirmed
the candidate count is in [TOP_K, MAX_CANDIDATES] (done==1). Eliminates
one full N-element scan in the common fast-convergence case.
OPT6 — Increase NUM_BINS from 256 to 2048 and replace the serial O(2048)
K-th-bin scan with a two-step parallel search: each warp accumulates its
128-bin slice (NUM_BINS/NUM_WARPS), then thread 0 locates the target
warp in 16 steps and one warp-lane scans 128 bins. Serial depth drops
from 2048 to 144 steps. Shared memory grows from ~50 KB to ~59 KB; the
histogram is now 8 KB instead of 1 KB.
OPT7 — Cache per-thread element counts into smem->per_thread_counts[]
inside blockCountGE. Phase 3 prefix-sum reuses these cached values
instead of repeating the N-scan, saving one full global-memory pass.
Also add distribution-parameterised correctness tests for the heuristic
indexer_topk_decode path in test_indexer_topk.py:
- Four logit families: beta (bounded), logistic (heavy-tailed), lognorm
(positively skewed), weibull_min (right-skewed extreme-value)
- MTP correlation: consecutive rows within each batch element share tail
logits (next_n up to 3)
- pre_idx accuracy sweep: success_ratio in {0.5, 0.9}
- Tolerance set to full_range/256 to accept histogram-bin boundary ties
Made-with: claude-4.6-opus-high
Signed-off-by: longcheng-nv <[email protected]>
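The two-step K-th-bin search from OPT6 can be sketched in Python as below. This is an illustration of the search structure only: it assumes the scan runs from bin 0 upward and accumulates counts until they reach K, which may not match the kernel's bin ordering. The function name `find_kth_bin` is hypothetical.

```python
NUM_BINS = 2048
NUM_WARPS = 16
BINS_PER_WARP = NUM_BINS // NUM_WARPS  # 128


def find_kth_bin(hist, k):
    """Return the first bin b such that sum(hist[0..b]) >= k."""
    assert len(hist) == NUM_BINS and k >= 1
    # On the GPU each warp would sum its own 128-bin slice in parallel.
    warp_totals = [
        sum(hist[w * BINS_PER_WARP:(w + 1) * BINS_PER_WARP]) for w in range(NUM_WARPS)
    ]
    # Step 1: at most 16 serial steps to locate the warp slice holding the K-th element.
    seen = 0
    target = None
    for w, total in enumerate(warp_totals):
        if seen + total >= k:
            target = w
            break
        seen += total
    if target is None:
        raise ValueError("k exceeds the number of counted elements")
    # Step 2: at most 128 serial steps inside that slice.
    for b in range(target * BINS_PER_WARP, (target + 1) * BINS_PER_WARP):
        seen += hist[b]
        if seen >= k:
            return b
    raise AssertionError("unreachable: the target slice must contain the K-th element")
```

The worst case is 16 + 128 = 144 serial steps, matching the figure quoted in the commit message, against 2048 for a flat scan.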
…flicts Pass caller-owned scratch buffer through heuristic TopK pipeline to eliminate cudaMallocAsync/cudaFreeAsync inside the kernel, enabling CUDA Graph capture. Also resolves stash merge conflicts in test file and adds seq_lens clamping for valid row lengths with next_n > 1. Signed-off-by: longcheng-nv <[email protected]> Made-with: claude-4.6-opus-high
…pK tests Add docstrings to all test functions and helpers to improve docstring coverage. Remove broken duplicate test_indexer_topk_decode_dist that referenced undefined names (dtype, run_fn). Suppress F811 on intentional CuTE DSL test redefinitions with different parametrize configurations. Made-with: claude-4.6-opus-high Signed-off-by: longcheng-nv <[email protected]>
…n test_dsa_indexer The DSA backend accesses sparse_attention_config.enable_heuristic_topk but the mock SparseAttentionConfig in test_dsa_indexer.py was missing this field, causing all 76 DSA indexer tests to fail with AttributeError. Made-with: claude-4.6-opus-high Signed-off-by: longcheng-nv <[email protected]>
Add one-line docstrings to all public functions, classes, and methods across the 5 Python files modified by this PR to satisfy the CI docstring coverage gate (was 41%, now ~100%). Signed-off-by: longcheng-nv <[email protected]> Made-with: claude-4.6-opus-high
Without heuristic_scratch the call silently falls back to the radix kernel. Made-with: claude-4.6-opus-high Signed-off-by: longcheng-nv <[email protected]>
The heuristic TopK kernel requires stride0 divisible by 4 for float4 loads. In production fp8_paged_mqa_logits always returns tensors with stride aligned to 256, but the test helpers created logits with arbitrary max_len from random seq_lens, causing assertion failures when max_len was not a multiple of 4. Made-with: claude-4.6-opus-high Signed-off-by: longcheng-nv <[email protected]>
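One way a test helper could satisfy the stride requirement is to round `max_len` up before allocating the logits. A minimal sketch follows; the helper name `align_up` is hypothetical and not taken from the test file.

```python
def align_up(n, multiple=4):
    """Round n up to the next multiple of `multiple` (4 for float4 loads)."""
    return (n + multiple - 1) // multiple * multiple


print(align_up(5))  # 8
print(align_up(70001, 256))  # 70144
```

Padding columns would still need to sit outside each row's valid length so they cannot be selected.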
…duplicate allocation
- Remove duplicate `heuristic_scratch_values` allocation from `create_expanded_buffers()` (already created in `__post_init__`); move resize logic into `update_spec_dec_param()` where shape changes.
- Eliminate `heuristic_pre_idx_staging` buffer entirely: pass `heuristic_prev_topk` directly as `pre_idx` to the kernel.
- Move +1 temporal offset into the C++ kernel (`preIdxOffset = (rowIdx % next_n) + 1`), removing two Python tensor ops (copy_ and += 1) from the CUDA Graph captured region per layer per decode step.
Addresses review comments from lfr-0531 on PR NVIDIA#12385.
Signed-off-by: longcheng-nv <[email protected]>
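The row-to-hint mapping that this commit moves into the kernel can be written out in Python. It combines the `rowIdx / next_n` hint-row indexing from the earlier MTP fix with the `preIdxOffset` formula above; the function name `hint_row_and_offset` is hypothetical.

```python
def hint_row_and_offset(row_idx, next_n):
    """Map a logits row to (hint row, temporal offset added to each hint index)."""
    hint_row = row_idx // next_n          # one stored hint row per request
    pre_idx_offset = (row_idx % next_n) + 1  # position P+1, P+2, ... within the step
    return hint_row, pre_idx_offset


print([hint_row_and_offset(r, 3) for r in range(6)])
# [(0, 1), (0, 2), (0, 3), (1, 1), (1, 2), (1, 3)]
```

With `next_n = 1` the offset is always 1, which reproduces the earlier Python-side `+= 1`.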
…on pre-Blackwell
- Add `get_sm_version() >= 100` guard to `enable_heuristic_topk` in both `DSAtrtllmAttentionMetadata` and `Indexer`, so the heuristic path silently falls back to radix sort on Hopper and older architectures.
- Add `@skip_pre_blackwell` decorator to `test_indexer_topk_decode_dist` in `test_indexer_topk.py` and `test_indexer_topk_dist.py`, since the heuristic kernel is only validated for Blackwell GPUs.
Fixes CI failure on DGX_H100: test_indexer_topk_decode_dist[8192-2048-2-128-logistic_m0.47_s1.46-0.5].
Made-with: claude-4.6-opus-high
Signed-off-by: longcheng-nv <[email protected]>
…eedback Revert docstring-only changes in cpp_custom_ops.py, model_config.py, and llm_args.py that were added in eda0351 but are outside the scope of this PR. Docstrings in files directly modified by this feature (dsa.py and test_dsa_indexer.py) are retained. Addresses feedback from lfr-0531 and kaiyux on PR NVIDIA#12385. Made-with: claude-4.6-opus-high Signed-off-by: longcheng-nv <[email protected]>
… and multi-GPU Add runtime topK == TOP_K validation, compile-time static_assert blocking non-32-bit index types, remove process-wide static flag in favor of unconditional cudaFuncSetAttribute, and drop unsafe reinterpret_casts. Signed-off-by: Long Cheng <[email protected]> Signed-off-by: longcheng-nv <[email protected]>
Add a parametrized variant to TestDeepSeekV32.test_fp8_blockscale that exercises the end-to-end heuristic TopK decode path with MTP (next_n=1). Gated to Blackwell (SM >= 100). Validates accuracy on MMLU and GSM8K with enable_heuristic_topk=True through the LLM API. Made-with: claude-4.6-opus-high Signed-off-by: longcheng-nv <[email protected]>
386cd7b to
5f59b99
Compare
|
/bot run --disable-fail-fast |
|
PR_Github #41297 [ run ] triggered by Bot. Commit: |
|
PR_Github #41297 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #41536 [ run ] triggered by Bot. Commit: |
|
PR_Github #41536 [ run ] completed with state |
…Sparse Attention (NVIDIA#12385) Signed-off-by: longcheng-nv <[email protected]>
Summary
heuristic_topk.cuh)
enable_heuristic_topk=False); opt-in via DeepSeekSparseAttentionConfig(enable_heuristic_topk=True)
Commit Breakdown (11 commits)
__noinline__ device function pattern
launch_bounds and restore sort threshold
Key Files
cpp/tensorrt_llm/kernels/heuristic_topk.cuh
cpp/tensorrt_llm/kernels/heuristicTopKDecode.cu
cpp/tensorrt_llm/kernels/indexerTopK.cu
tensorrt_llm/_torch/attention_backend/sparse/dsa.py
tensorrt_llm/_torch/model_config.py: enable_heuristic_topk config propagation
tensorrt_llm/llmapi/llm_args.py: DeepSeekSparseAttentionConfig.enable_heuristic_topk user option
API
Enable via DeepSeekSparseAttentionConfig(enable_heuristic_topk=True) in LlmArgs.
Test plan
pytest tests/unittest/_torch/thop/parallel/test_indexer_topk.py (distribution-parameterized correctness tests)
torch.topk across all test configurations
Accuracy Evaluation (DeepSeek V3.2 NVFP4 on B200, 8×TP/EP, DSA sparse attention)
MMLU (14,042 questions)
Delta: 0.00 — accuracy-neutral.
GSM8K (5-shot)
Delta: +0.13 — accuracy-neutral (within noise).
GPQA-Diamond (CoT zero-shot)
Delta: −0.67 — within GPQA-Diamond's ~3pt stderr.
LongBench v1 (8 runs, MTP=1)
Delta: −0.33 — accuracy-neutral (< 0.75%).
LongBench v2 (215 questions, medium context length, MTP=1)
Delta: −0.46% — LongBench v2 has high per-run variance (215 questions, range 47–52% even for baseline). With 5 matched runs per config, the delta is within expected noise.
Summary
All benchmarks are accuracy-neutral within their respective noise margins. MMLU and GSM8K (large-N, deterministic) show effectively zero impact. GPQA and LongBench (small-N, high variance) show deltas within expected run-to-run noise.
Performance Results (B200)
The heuristic TopK micro-kernel (heuristicTopKMultiRowKernel) is benchmarked against the default radix-sort path (topKPerRowDecode) on single-row logits input (one CTA per row). All benchmark runs first verify output indices match torch.topk exactly.
Random input (norm-similar distribution, y = 1 + 0.1 * N[0,1])
Heuristic kernel wins at N >= 16K; at shorter sequences, the radix sort is faster (expected — the heuristic's histogram overhead dominates at small N, which is why the dispatch threshold gates it).
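The gating conditions mentioned across this PR (hints present, scratch buffer provided, SM version at least 100, a minimum N) can be pulled together in one sketch. The actual dispatch threshold is not stated here, so `min_n_for_heuristic` below is a hypothetical parameter set to the 16K crossover from the benchmark; the function name and the returned labels are also hypothetical.

```python
def choose_topk_path(n, top_k, has_pre_idx, has_scratch, sm_version,
                     min_n_for_heuristic=16 * 1024):
    """Pick a TopK path label for one decode call (illustrative only)."""
    # Without hints or a scratch buffer the call uses the standard path.
    if not (has_pre_idx and has_scratch):
        return "standard"
    # The heuristic kernel is only validated on Blackwell (SM >= 100).
    if sm_version < 100:
        return "standard"
    # Simple case: everything is selected, no search needed.
    if n <= top_k:
        return "select_all"
    # Short rows: the histogram overhead dominates, so keep the standard path.
    if n < min_n_for_heuristic:
        return "standard"
    return "heuristic"


print(choose_topk_path(70000, 2048, True, True, 100))  # heuristic
```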
Realistic input (DeepSeek V3.2 decode logits, SWE-Bench ISL/OSL=64K/2K)
Profiled across 9 layers x 17 decode steps (N ~ 68.7K-70.7K):
The heuristic kernel achieves 1.81x average speedup over the radix sort on realistic DeepSeek V3.2 decode workloads at typical sequence lengths (~70K).