[TRTLLM-10061][feat] Add stride support for conv1d and fused_sigmoid_gating_delta_rule_update #12442

Merged
VALLIS-NERIA merged 2 commits into NVIDIA:main from VALLIS-NERIA:user/xiweny/stride_support
Mar 23, 2026
Conversation

@VALLIS-NERIA
Collaborator

@VALLIS-NERIA VALLIS-NERIA commented Mar 23, 2026

Summary

  • Replace hardcoded stride(0)==1 check with is_contiguous() in causalConv1dUpdate
  • Use explicit stride parameter (s_h0_0) instead of hardcoded HV*K*V for h0_source indexing in the fused_sigmoid_gating_delta_rule_update triton kernel, enabling non-contiguous initial_state_source layouts
  • Add input_guard_exclude decorator to selectively skip contiguous() on specified tensor arguments
  • Add int64 cast and device_assert bounds check for safer index computation
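The effect of passing the slot stride explicitly can be sketched in plain Python. This is only an illustration of the offset arithmetic, not the Triton kernel itself: the function names are hypothetical, and the sketch assumes the inner (HV, K, V) dimensions stay densely packed so that only the leading stride (s_h0_0) varies.

```python
# Illustrative sketch: plain Python standing in for the kernel's pointer
# arithmetic. Function names are hypothetical; s_h0_0, HV, K, V follow the
# PR description. Assumes the inner (HV, K, V) block is densely packed.

def state_offset_hardcoded(slot, hv, HV, K, V):
    """Offset of state block (slot, hv) when slots are assumed contiguous."""
    return slot * HV * K * V + hv * K * V


def state_offset_strided(slot, hv, s_h0_0, K, V):
    """Same offset, but the per-slot stride is supplied by the caller."""
    return slot * s_h0_0 + hv * K * V


HV, K, V = 2, 4, 4
dense_stride = HV * K * V           # elements per slot in a contiguous buffer
padded_stride = dense_stride + 16   # e.g. each slot is a slice of a wider row

# Contiguous layout: both formulas agree.
assert state_offset_strided(3, 1, dense_stride, K, V) == \
    state_offset_hardcoded(3, 1, HV, K, V)

# Padded (non-contiguous) layout: only the strided formula lands on the slot.
offset_padded = state_offset_strided(3, 1, padded_stride, K, V)
offset_wrong = state_offset_hardcoded(3, 1, HV, K, V)
```

With a padded layout the hardcoded formula points into the wrong slot, which is why the caller would pass something like `initial_state_source.stride(0)` instead of recomputing HV*K*V.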

Test plan

  • Existing unit tests pass
  • Integration tests with hybrid linear models using non-contiguous state layouts

Summary by CodeRabbit

  • Bug Fixes & Optimizations
    • Improved tensor validation logic for convolutional operations to enhance compatibility
    • Enhanced bounds checking and memory safety in kernel operations
    • Added support for flexible tensor layout handling to improve robustness

@VALLIS-NERIA VALLIS-NERIA requested a review from a team as a code owner March 23, 2026 07:58
@VALLIS-NERIA VALLIS-NERIA requested a review from yuxianq March 23, 2026 07:58
@coderabbitai
Contributor

coderabbitai Bot commented Mar 23, 2026

📝 Walkthrough

These changes modify initial-state tensor handling in GPU operations. A C++ validation check switches from stride to contiguity requirements. A new conditional guard decorator is introduced to selectively apply contiguity enforcement. The fused sigmoid gating kernel is updated to accept initial-state layout parameters and perform bounds checking.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Causal Convolution Validation: cpp/tensorrt_llm/thop/causalConv1dOp.cpp | Changed conv_state_indices_ validation from stride check (stride(0) == 1) to contiguity check (is_contiguous()). |
| Input Guard Decorator: tensorrt_llm/_torch/modules/fla/utils.py | Added new input_guard_exclude() decorator factory that selectively applies .contiguous() to tensor arguments while excluding specified parameter names. |
| Fused Sigmoid Gating Update: tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py | Updated kernel to accept initial-state layout parameters (s_h0_0, h0_dim0), added explicit int64 casts for index loading with overflow prevention, device-side bounds checking, and replaced input_guard with input_guard_exclude() on wrapper function. |
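The overflow that the int64 casts in the summary above guard against can be reproduced without a GPU. The sketch below simulates fixed-width multiplication with ctypes; the helper names and the concrete slot and stride values are made up for illustration and are not taken from the PR.

```python
import ctypes

# Hypothetical helpers that mimic fixed-width integer multiplication.
def mul_int32(a, b):
    """Multiply with 32-bit signed wraparound (ctypes truncates silently)."""
    return ctypes.c_int32(a * b).value


def mul_int64(a, b):
    """Multiply with 64-bit signed arithmetic."""
    return ctypes.c_int64(a * b).value


# Illustrative numbers only: a large cache slot index and a per-slot stride
# of HV*K*V elements with HV=32, K=128, V=128.
slot = 70_000
slot_stride = 32 * 128 * 128

wrapped = mul_int32(slot, slot_stride)  # wraps around to a negative offset
exact = mul_int64(slot, slot_stride)    # the intended offset
```

Widening the loaded index to 64 bits before multiplying by the stride keeps the offset exact; in 32-bit arithmetic the same product silently wraps to a negative value.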

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 37.50% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly and specifically summarizes the main changes: adding stride support for two functions (conv1d and fused_sigmoid_gating_delta_rule_update), which is directly reflected in the PR's core modifications. |
| Description check | ✅ Passed | The PR description covers the key changes with clear summaries but lacks detail on the Problem/Motivation section and contains incomplete test coverage items marked as to-do. |

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/tensorrt_llm/thop/causalConv1dOp.cpp`:
- Line 267: The TORCH_CHECK invocation validating conv_state_indices' scalar
type is missing a terminating semicolon causing a compile error; locate the line
containing TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) and
add a semicolon at the end of that statement so it reads as a complete statement
(ensure no other tokens are accidentally altered around the TORCH_CHECK or
conv_state_indices reference).

In `@tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py`:
- Around line 204-206: The else branch for when initial_state_source is None
sets s_h0_0 but fails to define slot_num, causing a NameError later when
constructing h0_dim0; update the else branch in the
fused_sigmoid_gating_recurrent logic to also initialize slot_num (e.g., set
slot_num = 0 or the appropriate default) whenever initial_state_source is None
so downstream use in h0_dim0 and any other uses of slot_num is valid; locate the
block around s_h0_0 and initial_state_source in
fused_sigmoid_gating_recurrent.py to make this change.

In `@tensorrt_llm/_torch/modules/fla/utils.py`:
- Around line 179-181: The positional-argument exclusion is broken because the
generator at contiguous_args compares tensor objects to strings (i in
exclude_args) instead of checking parameter names; update the logic in the
contiguous_args generator (and its surrounding function) to map positional
indices to parameter names using inspect.signature (or similar) and check the
parameter name against exclude_args before calling .contiguous(); i.e.,
determine each positional arg's parameter name from the function signature, and
only call i.contiguous() when the parameter name is not in exclude_args (leave
keyword exclusion using kwargs as-is).
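The name-based exclusion suggested in the last comment can be sketched as follows. This is not the code merged in the PR: FakeTensor is a hypothetical stand-in for torch.Tensor so the example runs without PyTorch, and the decorator body is one possible way to map positional arguments to parameter names with inspect.signature.

```python
import functools
import inspect


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs anywhere."""

    def __init__(self, name, is_contig=False):
        self.name = name
        self.is_contig = is_contig

    def contiguous(self):
        # Mirrors the torch API shape: returns a contiguous version.
        return FakeTensor(self.name, is_contig=True)


def input_guard_exclude(exclude_args):
    """Call .contiguous() on tensor arguments except the named parameters."""
    excluded = set(exclude_args)

    def decorator(fn):
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # bind_partial resolves every positional value to its parameter
            # name, so positional and keyword arguments are treated alike.
            bound = sig.bind_partial(*args, **kwargs)
            for name, value in list(bound.arguments.items()):
                if name not in excluded and isinstance(value, FakeTensor):
                    bound.arguments[name] = value.contiguous()
            return fn(*bound.args, **bound.kwargs)

        return wrapper

    return decorator


@input_guard_exclude(["initial_state_source"])
def update(q, initial_state_source=None):
    # Report which inputs were made contiguous by the guard.
    return q.is_contig, initial_state_source.is_contig


positional = update(FakeTensor("q"), FakeTensor("h0"))
by_keyword = update(q=FakeTensor("q"), initial_state_source=FakeTensor("h0"))
```

Because the check is on parameter names rather than on the tensor objects, the excluded argument keeps its original layout whether it is passed positionally or by keyword.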

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: cda05325-2eaf-4467-985f-c9ae20ba2641

📥 Commits

Reviewing files that changed from the base of the PR and between 7aa1383 and 4cea670.

📒 Files selected for processing (3)
  • cpp/tensorrt_llm/thop/causalConv1dOp.cpp
  • tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
  • tensorrt_llm/_torch/modules/fla/utils.py

Comment thread cpp/tensorrt_llm/thop/causalConv1dOp.cpp Outdated
Comment thread tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
Comment thread tensorrt_llm/_torch/modules/fla/utils.py Outdated
@VALLIS-NERIA VALLIS-NERIA force-pushed the user/xiweny/stride_support branch from 4cea670 to 2f1fed7 Compare March 23, 2026 08:10
@VALLIS-NERIA
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #39903 [ run ] triggered by Bot. Commit: 2f1fed7 Link to invocation

Comment thread tensorrt_llm/_torch/modules/fla/utils.py Outdated
Comment thread tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
Comment thread tensorrt_llm/_torch/modules/fla/utils.py Outdated
…gating_delta_rule_update

- Replace hardcoded stride(0)==1 check with is_contiguous() in
  causalConv1dUpdate
- Use explicit stride parameter (s_h0_0) instead of hardcoded
  HV*K*V for h0_source indexing in the triton kernel, enabling
  non-contiguous initial_state_source layouts
- Add int64 cast to prevent int32 overflow in index computation
- Add device_assert bounds check for h0_source store
- Add input_guard_exclude decorator to skip contiguous() on
  selected tensor arguments

Signed-off-by: Xiwen Yu <[email protected]>
@VALLIS-NERIA VALLIS-NERIA force-pushed the user/xiweny/stride_support branch from 2f1fed7 to 8620daa Compare March 23, 2026 08:54
@VALLIS-NERIA
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #39912 [ run ] triggered by Bot. Commit: 8620daa Link to invocation

@VALLIS-NERIA VALLIS-NERIA requested a review from yuxianq March 23, 2026 10:20
Comment thread tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py Outdated
Removed bounds checking assertion for h0_source index.

Signed-off-by: xiweny <[email protected]>
@VALLIS-NERIA
Collaborator Author

/bot run

@VALLIS-NERIA VALLIS-NERIA enabled auto-merge (squash) March 23, 2026 10:29
@tensorrt-cicd
Collaborator

PR_Github #39923 [ run ] triggered by Bot. Commit: 41f1b77 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #39923 [ run ] completed with state SUCCESS. Commit: 41f1b77
/LLM/main/L0_MergeRequest_PR pipeline #31090 completed with status: 'SUCCESS'

@VALLIS-NERIA VALLIS-NERIA merged commit 74f2efb into NVIDIA:main Mar 23, 2026
4 of 5 checks passed
longcheng-nv pushed a commit to longcheng-nv/TensorRT-LLM that referenced this pull request Mar 31, 2026
…gating_delta_rule_update (NVIDIA#12442)

Signed-off-by: Xiwen Yu <[email protected]>
Signed-off-by: xiweny <[email protected]>
3 participants