Skip to content

[None][fix] skip inference_mode() when torch.compile=True for gemma3 fp8#12367

Merged
amukkara merged 3 commits into
NVIDIA:mainfrom
amukkara:tc-fp8-gemma
Apr 11, 2026
Merged

[None][fix] skip inference_mode() when torch.compile=True for gemma3 fp8#12367
amukkara merged 3 commits into
NVIDIA:mainfrom
amukkara:tc-fp8-gemma

Conversation

@amukkara
Copy link
Copy Markdown
Collaborator

@amukkara amukkara commented Mar 19, 2026

Summary by CodeRabbit

Release Notes

  • Refactor

    • Improved inference mode handling during torch compilation for Gemma3 models, enabling proper execution context switching when compiling or tracing models.
  • Tests

    • Added test coverage for FP8-quantized Gemma3 1B model inference accuracy validation with torch compilation enabled.

Description

  • Add conditional decorator that skips torch.inference_mode() when inside torch.compile dynamo trace.

    • During compilation, torch.dynamo traces through the model and encounters .t() (transpose) being called on a weight tensor that is under torch.inference_mode(), for example here. The version_counter is a PyTorch mechanism to track in-place mutations — but inference tensors don't support it. When torch.compile's backend tries to set it (as part of functionalization/graph capture), it throws a RuntimeError: Cannot set version_counter for inference tensor.
  • Set maybe_execute_in_parallel(.., disable_on_compile=True) in QKNormRopeAttention following the pattern of [https://nvbugs/6029220][fix] Disable multi-stream in maybe_execute_i… #12659

Test Coverage

Parameterize tests/integration/defs/accuracy/test_llm_api_pytorch.py::test_fp8_prequantized with torch_compile=[True, False]

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@amukkara amukkara changed the title [None][fix] torch compile for gemma3 fp8 [None][fix] skip inference_mode() when torch.compile=True for gemma3 fp8 Mar 19, 2026
@amukkara amukkara force-pushed the tc-fp8-gemma branch 2 times, most recently from 54ea19a to 1505b7d Compare March 20, 2026 17:19
@amukkara amukkara marked this pull request as ready for review March 20, 2026 19:08
@amukkara amukkara requested review from a team as code owners March 20, 2026 19:08
@amukkara amukkara requested review from brb-nv and liji-nv March 20, 2026 19:08
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Mar 20, 2026

📝 Walkthrough

Walkthrough

This pull request adds a new inference_mode_unless_compiling decorator to conditionally control PyTorch inference mode behavior in Gemma3 model forward methods. The decorator activates torch.inference_mode() only when the torch compiler is not actively compiling. Additionally, a new integration test validates FP8 quantization with torch compilation on the Gemma3 1B model, along with corresponding test list entries.

Changes

Cohort / File(s) Summary
Model Implementation
tensorrt_llm/_torch/models/modeling_gemma3.py
Added inference_mode_unless_compiling decorator that conditionally applies torch.inference_mode() based on compiler state, applied to embedding, attention, decoder layer, decoder model, and top-level causal LM forward methods, replacing previous @torch.inference_mode() decorators.
Integration Test
tests/integration/defs/accuracy/test_llm_api_pytorch.py
Added test_fp8_torch_compile test method to TestGemma3_1BInstruct class to validate FP8 quantization accuracy with torch compilation enabled using MMLU evaluation.
Test Lists
tests/integration/test_lists/qa/llm_function_core.txt, tests/integration/test_lists/test-db/l0_h100.yml
Registered new FP8 torch compilation test entry in QA and H100 test lists for automated test discovery and execution.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 6.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly summarizes the main change: adding a conditional decorator to skip torch.inference_mode() when torch.compile is enabled for Gemma3 FP8 model.
Description check ✅ Passed The PR description clearly explains the problem (version_counter error during torch.compile) and the solution (conditional decorator to skip inference_mode during dynamo tracing). Test coverage is adequately specified.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/models/modeling_gemma3.py`:
- Line 1: Add the required NVIDIA Apache-2.0 copyright/license header as the
very first lines of tensorrt_llm/_torch/models/modeling_gemma3.py (i.e., place
it before the existing executable statement "import functools"); include the
correct NVIDIA copyright statement with the year of latest meaningful
modification and the full Apache License 2.0 notice used across the repo so the
file conforms to the repository header policy.
- Around line 35-39: The compiled branch in the wrapper function (guarded by
torch.compiler.is_compiling()) directly calls func and thus loses the grad-off
semantics applied in the non-compiled branch via torch.inference_mode(); update
wrapper so that when torch.compiler.is_compiling() is True it executes func
inside a torch.no_grad() context (i.e., wrap the call to func in
torch.no_grad()) to ensure consistent gradient-disabled behavior between
compiled and non-compiled paths while keeping the existing
torch.inference_mode() for the non-compiled branch.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a89e8eed-7b9d-4d11-b09d-62255ca821af

📥 Commits

Reviewing files that changed from the base of the PR and between 68001ce and 1505b7d.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/models/modeling_gemma3.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tests/integration/test_lists/qa/llm_function_core.txt
  • tests/integration/test_lists/test-db/l0_h100.yml

Comment thread tensorrt_llm/_torch/models/modeling_gemma3.py Outdated
Comment thread tensorrt_llm/_torch/models/modeling_gemma3.py Outdated
@brb-nv
Copy link
Copy Markdown
Collaborator

brb-nv commented Mar 31, 2026

Couple of questions based on MR description:

  1. What is an inference tensor? Is it an activation tensor (as opposed to weight tensor)?
  2. Can we parameterize existing test instead of adding a brand new one?

Copy link
Copy Markdown
Collaborator

@brb-nv brb-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Minor comments.

@amukkara
Copy link
Copy Markdown
Collaborator Author

amukkara commented Apr 3, 2026

Couple of questions based on MR description:

  1. What is an inference tensor? Is it an activation tensor (as opposed to weight tensor)?
  2. Can we parameterize existing test instead of adding a brand new one?
  1. It applies to both activations and weights. no tracking of gradients (for backward pass in training) and versions counters. gives better performance in eager mode, but conflicts with torch.compile backend.
  2. Done

@amukkara amukkara force-pushed the tc-fp8-gemma branch 2 times, most recently from 4cfe756 to 8dbcb50 Compare April 9, 2026 21:50
@amukkara amukkara requested a review from a team as a code owner April 9, 2026 21:50
@amukkara amukkara requested a review from HuiGao-NV April 9, 2026 21:50
@amukkara
Copy link
Copy Markdown
Collaborator Author

amukkara commented Apr 9, 2026

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42585 [ run ] triggered by Bot. Commit: 8dbcb50 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42585 [ run ] completed with state SUCCESS. Commit: 8dbcb50
/LLM/main/L0_MergeRequest_PR pipeline #33311 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@amukkara
Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42648 [ run ] triggered by Bot. Commit: 8dbcb50 Link to invocation

Copy link
Copy Markdown
Collaborator

@HuiGao-NV HuiGao-NV left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment thread tensorrt_llm/_torch/models/modeling_gemma3.py Outdated
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42648 [ run ] completed with state SUCCESS. Commit: 8dbcb50
/LLM/main/L0_MergeRequest_PR pipeline #33360 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@amukkara
Copy link
Copy Markdown
Collaborator Author

/bot run

Signed-off-by: Anurag Mukkara <[email protected]>
@amukkara
Copy link
Copy Markdown
Collaborator Author

/bot run

@amukkara amukkara requested a review from a team as a code owner April 10, 2026 18:23
@amukkara amukkara requested a review from lfr-0531 April 10, 2026 18:23
@amukkara
Copy link
Copy Markdown
Collaborator Author

/bot run

@amukkara amukkara enabled auto-merge (squash) April 10, 2026 18:26
@amukkara
Copy link
Copy Markdown
Collaborator Author

/bot kill

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42736 [ kill ] completed with state SUCCESS. Commit: 4d6da83
Successfully killed previous jobs for commit 4d6da83

Link to invocation

@NVIDIA NVIDIA deleted a comment from tensorrt-cicd Apr 10, 2026
@NVIDIA NVIDIA deleted a comment from tensorrt-cicd Apr 10, 2026
@NVIDIA NVIDIA deleted a comment from tensorrt-cicd Apr 10, 2026
@NVIDIA NVIDIA deleted a comment from tensorrt-cicd Apr 10, 2026
@NVIDIA NVIDIA deleted a comment from tensorrt-cicd Apr 10, 2026
@amukkara
Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42737 [ run ] triggered by Bot. Commit: 4d6da83 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42737 [ run ] completed with state SUCCESS. Commit: 4d6da83
/LLM/main/L0_MergeRequest_PR pipeline #33419 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@amukkara
Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42764 [ run ] triggered by Bot. Commit: 4d6da83 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #42764 [ run ] completed with state SUCCESS. Commit: 4d6da83
/LLM/main/L0_MergeRequest_PR pipeline #33441 completed with status: 'SUCCESS'

CI Report

Link to invocation

@amukkara amukkara merged commit ce80c14 into NVIDIA:main Apr 11, 2026
5 checks passed
@amukkara amukkara deleted the tc-fp8-gemma branch April 30, 2026 19:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants