[TRTLLM-10061][feat] Add stride support for conv1d and fused_sigmoid_gating_delta_rule_update#12442
📝 Walkthrough
These changes modify initial-state tensor handling in GPU operations. A C++ validation check switches from a stride requirement to a contiguity requirement. A new conditional guard decorator is introduced to selectively apply contiguity enforcement. The fused sigmoid gating kernel is updated to accept initial-state layout parameters and perform bounds checking.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/thop/causalConv1dOp.cpp`:
- Line 267: The TORCH_CHECK invocation validating conv_state_indices' scalar
type is missing a terminating semicolon causing a compile error; locate the line
containing TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) and
add a semicolon at the end of that statement so it reads as a complete statement
(ensure no other tokens are accidentally altered around the TORCH_CHECK or
conv_state_indices reference).
In `@tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py`:
- Around line 204-206: The else branch for when initial_state_source is None
sets s_h0_0 but fails to define slot_num, causing a NameError later when
constructing h0_dim0; update the else branch in the
fused_sigmoid_gating_recurrent logic to also initialize slot_num (e.g., set
slot_num = 0 or the appropriate default) whenever initial_state_source is None
so downstream use in h0_dim0 and any other uses of slot_num is valid; locate the
block around s_h0_0 and initial_state_source in
fused_sigmoid_gating_recurrent.py to make this change.
In `@tensorrt_llm/_torch/modules/fla/utils.py`:
- Around line 179-181: The positional-argument exclusion is broken because the
generator at contiguous_args compares tensor objects to strings (i in
exclude_args) instead of checking parameter names; update the logic in the
contiguous_args generator (and its surrounding function) to map positional
indices to parameter names using inspect.signature (or similar) and check the
parameter name against exclude_args before calling .contiguous(); i.e.,
determine each positional arg's parameter name from the function signature, and
only call i.contiguous() when the parameter name is not in exclude_args (leave
keyword exclusion using kwargs as-is).
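A minimal sketch of the name-based exclusion that the `utils.py` comment asks for, assuming a decorator factory that takes a list of parameter names. The body is hypothetical and is not the code in `tensorrt_llm/_torch/modules/fla/utils.py`; it only illustrates mapping positional arguments to parameter names with `inspect.signature` before deciding whether to call `.contiguous()`. Duck typing on `.contiguous()` stands in for a tensor type check so the sketch runs without PyTorch.

```python
import functools
import inspect


def input_guard_exclude(exclude_args):
    """Hypothetical sketch: make tensor-like inputs contiguous unless excluded by name."""
    excluded = frozenset(exclude_args)

    def decorator(fn):
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # bind_partial pairs every positional and keyword argument
            # with the name of the parameter it fills.
            bound = sig.bind_partial(*args, **kwargs)
            for name, value in list(bound.arguments.items()):
                # Compare the parameter *name* against the exclusion set,
                # never the argument object itself.
                if name not in excluded and hasattr(value, "contiguous"):
                    bound.arguments[name] = value.contiguous()
            return fn(*bound.args, **bound.kwargs)

        return wrapper

    return decorator
```

Because the lookup goes through the bound signature, an excluded argument is skipped whether the caller passes it positionally or by keyword. Arguments collected by `*args` or `**kwargs` arrive as a tuple or dict and are passed through unchanged in this sketch.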
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: cda05325-2eaf-4467-985f-c9ae20ba2641
📒 Files selected for processing (3)
cpp/tensorrt_llm/thop/causalConv1dOp.cpp
tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
tensorrt_llm/_torch/modules/fla/utils.py
4cea670 to 2f1fed7
/bot run
PR_Github #39903 [ run ] triggered by Bot. Commit:
…gating_delta_rule_update
- Replace hardcoded stride(0)==1 check with is_contiguous() in causalConv1dUpdate
- Use explicit stride parameter (s_h0_0) instead of hardcoded HV*K*V for h0_source indexing in the triton kernel, enabling non-contiguous initial_state_source layouts
- Add int64 cast to prevent int32 overflow in index computation
- Add device_assert bounds check for h0_source store
- Add input_guard_exclude decorator to skip contiguous() on selected tensor arguments
Signed-off-by: Xiwen Yu <[email protected]>
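To illustrate why the commit passes an explicit stride and widens the index, here is a small pure-Python sketch. It is not the Triton kernel: the pool layout `(num_slots, 2, HV, K, V)`, the sizes, and the helper names are all assumptions chosen for the example. It shows that a slot offset computed from `HV*K*V` is wrong for a strided view, and that a large `slot * stride` product wraps when held in 32 bits.

```python
def contiguous_strides(shape):
    """Row-major strides, in elements, for a tensor of the given shape."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def wrap_int32(value):
    """Emulate two's-complement wraparound of a 32-bit signed integer."""
    return (value + 2**31) % 2**32 - 2**31


def slot_offset(slot, stride0):
    """Element offset of a state slot. Python ints are unbounded,
    which stands in for doing the multiply in int64."""
    return slot * stride0


# Hypothetical layout: a pool of shape (num_slots, 2, HV, K, V), where the
# initial state is the view pool[:, 0] with shape (num_slots, HV, K, V).
HV, K, V = 4, 8, 16
pool_strides = contiguous_strides((3, 2, HV, K, V))
view_strides = (pool_strides[0],) + pool_strides[2:]  # drop the sliced dim

s_h0_0 = view_strides[0]           # 2 * HV * K * V = 1024
hardcoded = HV * K * V             # 512, only valid for a contiguous view

correct = slot_offset(1, s_h0_0)   # start of pool[1, 0]
wrong = slot_offset(1, hardcoded)  # start of pool[0, 1], a different buffer

# Overflow: with larger (again hypothetical) sizes the product exceeds 2**31 - 1.
big_stride = 2 * 32 * 128 * 128
true_offset = slot_offset(5000, big_stride)
wrapped_offset = wrap_int32(true_offset)  # what a 32-bit multiply would yield
```

For a contiguous `(num_slots, HV, K, V)` tensor the two offsets coincide, which is why the hardcoded form only breaks once strided inputs are allowed.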
2f1fed7 to 8620daa
/bot run
PR_Github #39912 [ run ] triggered by Bot. Commit:
Removed bounds checking assertion for h0_source index. Signed-off-by: xiweny <[email protected]>
/bot run
PR_Github #39923 [ run ] triggered by Bot. Commit:
PR_Github #39923 [ run ] completed with state
…gating_delta_rule_update (NVIDIA#12442) Signed-off-by: Xiwen Yu <[email protected]> Signed-off-by: xiweny <[email protected]>
Summary
- Replace hardcoded `stride(0)==1` check with `is_contiguous()` in `causalConv1dUpdate`
- Use explicit stride parameter (`s_h0_0`) instead of hardcoded `HV*K*V` for `h0_source` indexing in the fused_sigmoid_gating_delta_rule_update triton kernel, enabling non-contiguous `initial_state_source` layouts
- Add `input_guard_exclude` decorator to selectively skip `contiguous()` on specified tensor arguments
Test plan
Summary by CodeRabbit