[None][feat] Mamba optimization and mixed quantization support for nemotron-h#11972
Conversation
|
@Wanli-Jiang I think to test, we can add an extra line in the dockerfile here to install flashinfer nightly - |
6b6516f to
9761690
Compare
|
/bot run --disable-fail-fast |
📝 WalkthroughWalkthroughThis pull request introduces support for Mamba SSM stochastic rounding by adding a new configuration field that flows from CLI arguments through config classes to the Mamba2 mixer implementation. The change includes conditional routing logic between FlashInfer and native execution paths based on hardware capabilities and dtype constraints. The FlashInfer dependency is updated to an URL-based installation reference. Changes
Sequence DiagramsequenceDiagram
actor User
participant CLI as CLI Parser
participant KvCache as KvCacheConfig
participant Loader as ModelLoader
participant QuantCfg as QuantConfig
participant Mixer as Mamba2Mixer
User->>CLI: --mamba_ssm_stochastic_rounding flag
CLI->>KvCache: args.mamba_ssm_stochastic_rounding
KvCache->>Loader: kv_cache_config.mamba_ssm_stochastic_rounding
Loader->>QuantCfg: validate_and_set_mamba_ssm_cache_dtype()
QuantCfg->>QuantCfg: Set mamba_ssm_stochastic_rounding
QuantCfg->>Mixer: config.mamba_ssm_stochastic_rounding
Mixer->>Mixer: Check head_dim in [64, 128]?
Mixer->>Mixer: Check dtype == float16?
Mixer->>Mixer: Check FlashInfer available?
alt All conditions met
Mixer->>Mixer: _use_stochastic_rounding = True
Mixer->>Mixer: Add rand_seed to kwargs
else Conditions not met
Mixer->>Mixer: _use_stochastic_rounding = False
Mixer->>Mixer: Emit warning
end
alt _use_flashinfer enabled
Mixer->>Mixer: Route to FlashInfer path
else
Mixer->>Mixer: Route to native implementation
end
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 1 | ❌ 2❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/pyexecutor/model_loader.py (1)
37-52:⚠️ Potential issue | 🟡 MinorFail fast when stochastic rounding resolves to a non-FP16 cache dtype.
"auto"is only resolved here, but the new flag is copied through unconditionally. If the resolved Mamba SSM cache dtype ends up as BF16/FP32, the config still carries an unusable stochastic-rounding request deeper into the runtime instead of rejecting it at the first point where the actual dtype is known.Suggested fix
def validate_and_set_mamba_ssm_cache_dtype( config: ModelConfig, mamba_ssm_cache_dtype: str, mamba_ssm_stochastic_rounding: bool = False) -> None: @@ - config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype - config.quant_config.mamba_ssm_stochastic_rounding = mamba_ssm_stochastic_rounding + config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype + if mamba_ssm_stochastic_rounding and mamba_ssm_cache_dtype != torch.float16: + raise ValueError( + "kv_cache_config.mamba_ssm_stochastic_rounding requires " + 'kv_cache_config.mamba_ssm_cache_dtype="float16"' + ) + config.quant_config.mamba_ssm_stochastic_rounding = mamba_ssm_stochastic_rounding🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/pyexecutor/model_loader.py` around lines 37 - 52, In validate_and_set_mamba_ssm_cache_dtype, after resolving mamba_ssm_cache_dtype (via str_dtype_to_torch or config.pretrained_config.torch_dtype), immediately check if mamba_ssm_stochastic_rounding is True and the resolved dtype is not torch.float16 (FP16); if so, raise a ValueError (or similar) rejecting the incompatible combination instead of silently storing it on config.quant_config.mamba_ssm_stochastic_rounding; otherwise continue to set config.quant_config.mamba_ssm_cache_dtype and mamba_ssm_stochastic_rounding as before. Ensure you reference the resolved mamba_ssm_cache_dtype and the boolean mamba_ssm_stochastic_rounding within validate_and_set_mamba_ssm_cache_dtype (and use ModelConfig/quant_config fields) so the check occurs before writing into config.quant_config.tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (1)
1-1:⚠️ Potential issue | 🟡 MinorUpdate copyright year to include 2026.
The copyright header currently shows
2022-2024, but this file has meaningful modifications in 2026. As per coding guidelines, the copyright header should reflect the year of the latest meaningful modification.Suggested fix
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py` at line 1, Update the file copyright header to reflect the latest modification year: change the existing "2022-2024" string in the top-of-file comment to "2022-2026" so the header reads "Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES"; locate the header in tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (top file comment) and perform the replacement while preserving the SPDX and surrounding comment formatting.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (1)
444-453: Minor inconsistency:dt_softplusdiffers between MTP and non-MTP paths.The MTP path uses
dt_softplus=True(line 405) while the non-MTP path usesdt_softplus=self.delta_softplus(line 449). If this is intentional for speculative decoding behavior, consider adding a brief comment explaining why MTP always usesTrue.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py` around lines 444 - 453, The dt_softplus flag is inconsistent between the MTP path (where dt_softplus=True is hard-coded) and the non-MTP path (where dt_softplus=self.delta_softplus) around the selective_state_update call; either make them consistent or document the intentional difference. Locate the MTP branch that builds ssu_kwargs with dt_softplus=True and the non-MTP branch that sets dt_softplus=self.delta_softplus (used when calling selective_state_update / selective_state_update in mamba2_mixer) and add a short inline comment explaining why MTP forces True for speculative decoding (or change the MTP assignment to use self.delta_softplus if it should match behavior) so the difference is explicit and not surprising.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py`:
- Line 1: Update the file copyright header to reflect the latest modification
year: change the existing "2022-2024" string in the top-of-file comment to
"2022-2026" so the header reads "Copyright (c) 2022-2026 NVIDIA CORPORATION &
AFFILIATES"; locate the header in
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (top file comment) and perform
the replacement while preserving the SPDX and surrounding comment formatting.
In `@tensorrt_llm/_torch/pyexecutor/model_loader.py`:
- Around line 37-52: In validate_and_set_mamba_ssm_cache_dtype, after resolving
mamba_ssm_cache_dtype (via str_dtype_to_torch or
config.pretrained_config.torch_dtype), immediately check if
mamba_ssm_stochastic_rounding is True and the resolved dtype is not
torch.float16 (FP16); if so, raise a ValueError (or similar) rejecting the
incompatible combination instead of silently storing it on
config.quant_config.mamba_ssm_stochastic_rounding; otherwise continue to set
config.quant_config.mamba_ssm_cache_dtype and mamba_ssm_stochastic_rounding as
before. Ensure you reference the resolved mamba_ssm_cache_dtype and the boolean
mamba_ssm_stochastic_rounding within validate_and_set_mamba_ssm_cache_dtype (and
use ModelConfig/quant_config fields) so the check occurs before writing into
config.quant_config.
---
Nitpick comments:
In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py`:
- Around line 444-453: The dt_softplus flag is inconsistent between the MTP path
(where dt_softplus=True is hard-coded) and the non-MTP path (where
dt_softplus=self.delta_softplus) around the selective_state_update call; either
make them consistent or document the intentional difference. Locate the MTP
branch that builds ssu_kwargs with dt_softplus=True and the non-MTP branch that
sets dt_softplus=self.delta_softplus (used when calling selective_state_update /
selective_state_update in mamba2_mixer) and add a short inline comment
explaining why MTP forces True for speculative decoding (or change the MTP
assignment to use self.delta_softplus if it should match behavior) so the
difference is explicit and not surprising.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 0068c16a-62e8-4517-af61-f1083da999f5
📒 Files selected for processing (6)
examples/llm-api/quickstart_advanced.pyrequirements.txttensorrt_llm/_torch/modules/mamba/mamba2_mixer.pytensorrt_llm/_torch/pyexecutor/model_loader.pytensorrt_llm/llmapi/llm_args.pytensorrt_llm/models/modeling_utils.py
|
PR_Github #38209 [ run ] triggered by Bot. Commit: |
9761690 to
c49539c
Compare
|
PR_Github #38209 [ run ] completed with state
|
…r Mamba SSM cache Signed-off-by: Wanli Jiang <[email protected]>
Signed-off-by: Wanli Jiang <[email protected]>
Signed-off-by: Izzy Putterman <[email protected]>
Superjomn
left a comment
There was a problem hiding this comment.
LGTM on the llmapi changes.
|
/bot run --disable-fail-fast |
|
/bot run --stage-list "Build-Docker-Images" |
|
PR_Github #38368 [ run ] triggered by Bot. Commit: |
|
PR_Github #38369 [ run ] triggered by Bot. Commit: |
|
PR_Github #38368 [ run ] completed with state |
|
/bot help |
GitHub Bot Help
Provide a user friendly way for developers to interact with a Jenkins server. Run See details below for each supported subcommand. Details
Launch build/test pipelines. All previously running jobs will be killed.
kill
Kill all running builds associated with pull request. skip
Skip testing for latest commit on pull request. reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break. |
|
/bot run --disable-fail-fast --extra-stage "Build-Docker-Images" |
|
PR_Github #38372 [ run ] triggered by Bot. Commit: |
QiJune
left a comment
There was a problem hiding this comment.
LGTM for the API change.
|
/bot run --only-multi-gpu-test --disable-fail-fast |
|
PR_Github #38437 [ run ] triggered by Bot. Commit: |
|
PR_Github #38437 [ run ] completed with state
|
Signed-off-by: Wanli Jiang <[email protected]>
Signed-off-by: Wanli Jiang <[email protected]>
|
/bot run --stage-list "DGX_B200-4_GPUs-PyTorch-3,DGX_H100-4_GPUs-PyTorch-Others-2" --disable-fail-fast |
|
PR_Github #38524 [ run ] triggered by Bot. Commit: |
|
PR_Github #38524 [ run ] completed with state |
|
/bot skip --comment “Skipped since the duplicated PR12072 is passed CI testing" |
|
PR_Github #38584 [ ] completed with state |
…motron-h (NVIDIA#11972) Signed-off-by: Wanli Jiang <[email protected]> Signed-off-by: Izzy Putterman <[email protected]> Co-authored-by: Izzy Putterman <[email protected]>
…motron-h (NVIDIA#11972) Signed-off-by: Wanli Jiang <[email protected]> Signed-off-by: Izzy Putterman <[email protected]> Co-authored-by: Izzy Putterman <[email protected]>
Features:
TODO:
Summary by CodeRabbit
New Features
Chores
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.